diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/image_processor.py

@@ -29,15 +29,34 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 PipelineImageInput = Union[
     PIL.Image.Image,
     np.ndarray,
-    torch.FloatTensor,
+    torch.Tensor,
     List[PIL.Image.Image],
     List[np.ndarray],
-    List[torch.FloatTensor],
+    List[torch.Tensor],
 ]
 
 PipelineDepthInput = PipelineImageInput
 
 
+def is_valid_image(image):
+    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
+
+
+def is_valid_image_imagelist(images):
+    # check if the image input is one of the supported formats for image and image list:
+    # it can be either one of below 3
+    # (1) a 4d pytorch tensor or numpy array,
+    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
+    # (3) a list of valid image
+    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+        return True
+    elif is_valid_image(images):
+        return True
+    elif isinstance(images, list):
+        return all(is_valid_image(image) for image in images)
+    return False
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
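
The two validators are plain module-level helpers, so downstream code can call them directly. A minimal sketch of what they accept (assuming diffusers >= 0.28.0, mirroring the comments in the hunk above):

```py
import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

# single images: a PIL image, a 2-d (grayscale) or 3-d array/tensor
assert is_valid_image(PIL.Image.new("RGB", (64, 64)))
assert is_valid_image(np.zeros((64, 64)))
assert is_valid_image(torch.zeros(3, 64, 64))

# image lists: a 4-d batch, a single valid image, or a list of valid images
assert is_valid_image_imagelist(np.zeros((4, 64, 64, 3)))
assert is_valid_image_imagelist([PIL.Image.new("RGB", (64, 64))] * 2)
assert not is_valid_image_imagelist("not an image")
```
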
@@ -80,7 +99,6 @@ class VaeImageProcessor(ConfigMixin):
                 " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
                 " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
             )
-            self.config.do_convert_rgb = False
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
@@ -111,7 +129,7 @@ class VaeImageProcessor(ConfigMixin):
         return images
 
     @staticmethod
-    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
+    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
         """
         Convert a NumPy image to a PyTorch tensor.
         """
@@ -122,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
         return images
 
     @staticmethod
-    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
+    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
         """
         Convert a PyTorch tensor to a NumPy image.
         """
@@ -173,8 +191,9 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
         """
-        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
-        for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
+        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
+        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
+        processing are 512x512, the region will be expanded to 128x128.
 
         Args:
             mask_image (PIL.Image.Image): Mask image.
@@ -183,7 +202,8 @@ class VaeImageProcessor(ConfigMixin):
             pad (int, optional): Padding to be added to the crop region. Defaults to 0.
 
         Returns:
-            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
+            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
+            matches the original aspect ratio.
         """
 
         mask_image = mask_image.convert("L")
@@ -265,7 +285,8 @@ class VaeImageProcessor(ConfigMixin):
         height: int,
     ) -> PIL.Image.Image:
         """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, filling empty with data from image.
 
         Args:
             image: The image to resize.
@@ -309,7 +330,8 @@ class VaeImageProcessor(ConfigMixin):
         height: int,
     ) -> PIL.Image.Image:
         """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, cropping the excess.
 
         Args:
             image: The image to resize.
@@ -346,12 +368,12 @@ class VaeImageProcessor(ConfigMixin):
                 The width to resize to.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
+                will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
+                then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
+                the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+                the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
 
         Returns:
             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
@@ -456,19 +478,21 @@ class VaeImageProcessor(ConfigMixin):
 
         Args:
             image (`pipeline_image_input`):
-                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
+                supported formats.
             height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
+                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
+                height.
             width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
+                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
             resize_mode (`str`, *optional*, defaults to `default`):
-                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
+                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
+                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
+                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
+                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
         """
@@ -492,12 +516,27 @@ class VaeImageProcessor(ConfigMixin):
             else:
                 image = np.expand_dims(image, axis=-1)
 
-        if isinstance(image, supported_formats):
-            image = [image]
-        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
+        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
+                FutureWarning,
+            )
+            image = np.concatenate(image, axis=0)
+        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
+                FutureWarning,
+            )
+            image = torch.cat(image, axis=0)
+
+        if not is_valid_image_imagelist(image):
             raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
+                f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}"
             )
+        if not isinstance(image, list):
+            image = [image]
 
         if isinstance(image[0], PIL.Image.Image):
             if crops_coords is not None:
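
The rewritten validation also deprecates passing `preprocess` a list of 4-d arrays or tensors. A short migration sketch (assuming diffusers >= 0.28.0):

```py
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
chunks = [torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)]  # two 4-d batches

# deprecated: processor.preprocess(chunks) still works, but emits a FutureWarning
# and concatenates internally; concatenate up front instead:
batch = torch.cat(chunks, dim=0)   # a single 4-d tensor of shape (4, 3, 64, 64)
out = processor.preprocess(batch)  # normalized to [-1, 1] by default
```
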
@@ -556,15 +595,15 @@ class VaeImageProcessor(ConfigMixin):
 
     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.
 
         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -573,7 +612,7 @@ class VaeImageProcessor(ConfigMixin):
                 `VaeImageProcessor` config.
 
         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -733,15 +772,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.
 
         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -750,7 +789,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
                 `VaeImageProcessor` config.
 
         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -788,8 +827,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     def preprocess(
         self,
-        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
+        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
         height: Optional[int] = None,
         width: Optional[int] = None,
         target_res: Optional[int] = None,
@@ -928,13 +967,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
         )
 
     @staticmethod
-    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
+    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
         """
-        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
-        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
+        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
 
         Args:
-            mask (`torch.FloatTensor`):
+            mask (`torch.Tensor`):
                 The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
             batch_size (`int`):
                 The batch size.
@@ -944,7 +983,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
                 The dimensionality of the value embeddings.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The downsampled mask tensor.
 
         """
@@ -988,3 +1027,77 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
         )
 
         return mask_downsample
+
+
+class PixArtImageProcessor(VaeImageProcessor):
+    """
+    Image processor for PixArt image resize and crop.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to be `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
+            Whether to convert the images to grayscale format.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+        """Returns binned height and width."""
+        ar = float(height / width)
+        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+        default_hw = ratios[closest_ratio]
+        return int(default_hw[0]), int(default_hw[1])
+
+    @staticmethod
+    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        orig_height, orig_width = samples.shape[2], samples.shape[3]
+
+        # Check if resizing is needed
+        if orig_height != new_height or orig_width != new_width:
+            ratio = max(new_height / orig_height, new_width / orig_width)
+            resized_width = int(orig_width * ratio)
+            resized_height = int(orig_height * ratio)
+
+            # Resize
+            samples = F.interpolate(
+                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+            )
+
+            # Center Crop
+            start_x = (resized_width - new_width) // 2
+            end_x = start_x + new_width
+            start_y = (resized_height - new_height) // 2
+            end_y = start_y + new_height
+            samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+        return samples
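
`PixArtImageProcessor` backs the aspect-ratio binning used by the PixArt pipelines. A sketch of the two static helpers; the toy `ratios` dict is a stand-in for the `ASPECT_RATIO_*_BIN` tables shipped with those pipelines:

```py
import torch

from diffusers.image_processor import PixArtImageProcessor

# hypothetical bin table: "height/width ratio" -> (binned height, binned width)
ratios = {"0.5": (512, 1024), "1.0": (1024, 1024), "2.0": (1024, 512)}

# snap a requested 900x1000 output to the closest supported bin
h, w = PixArtImageProcessor.classify_height_width_bin(900, 1000, ratios)  # (1024, 1024)

# after generation, resize and center-crop the binned result back to the request
samples = torch.rand(1, 3, h, w)
restored = PixArtImageProcessor.resize_and_crop_tensor(samples, new_width=1000, new_height=900)
assert restored.shape == (1, 3, 900, 1000)
```
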
diffusers/loaders/__init__.py

@@ -54,9 +54,7 @@ if is_transformers_available():
 _import_structure = {}
 
 if is_torch_available():
-    _import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
-
-    _import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
+    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
@@ -70,8 +68,7 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from .autoencoder import FromOriginalVAEMixin
-        from .controlnet import FromOriginalControlNetMixin
+        from .single_file_model import FromOriginalModelMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
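
For downstream imports the rename is mechanical; a sketch of the migration:

```py
# before (0.27.x):
#   from diffusers.loaders.autoencoder import FromOriginalVAEMixin
#   from diffusers.loaders.controlnet import FromOriginalControlNetMixin

# after (0.28.0): a single mixin handles loading models from original checkpoints
from diffusers.loaders.single_file_model import FromOriginalModelMixin
```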
 
diffusers/loaders/autoencoder.py

@@ -50,9 +50,9 @@ class FromOriginalVAEMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -99,7 +99,7 @@ class FromOriginalVAEMixin:
 
         original_config_file = kwargs.pop("original_config_file", None)
         config_file = kwargs.pop("config_file", None)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)

diffusers/loaders/controlnet.py

@@ -50,9 +50,9 @@ class FromOriginalControlNetMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -89,7 +89,7 @@ class FromOriginalControlNetMixin:
         """
         original_config_file = kwargs.pop("original_config_file", None)
         config_file = kwargs.pop("config_file", None)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
diffusers/loaders/ip_adapter.py

@@ -16,17 +16,20 @@ from pathlib import Path
 from typing import Dict, List, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors import safe_open
 
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
+    USE_PEFT_BACKEND,
     _get_model_file,
     is_accelerate_available,
     is_torch_version,
     is_transformers_available,
     logging,
 )
+from .unet_loader_utils import _maybe_expand_lora_scales
 
 
 if is_transformers_available():
@@ -36,6 +39,8 @@ if is_transformers_available():
     )
 
     from ..models.attention_processor import (
+        AttnProcessor,
+        AttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
     )
@@ -67,26 +72,27 @@ class IPAdapterMixin:
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
             subfolder (`str` or `List[str]`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally.
-                If a list is passed, it should have the same length as `weight_name`.
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
             weight_name (`str` or `List[str]`):
                 The name of the weight file to load. If a list is passed, it should have the same length as
                 `weight_name`.
             image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
                 The subfolder location of the image encoder within a larger model repository on the Hub or locally.
-                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
-                you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
-                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
-                for example, `image_encoder_folder="different_subfolder/image_encoder"`.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+                `image_encoder_folder="different_subfolder/image_encoder"`.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
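
For reference, the string and list call shapes the docstring above describes (a sketch; `h94/IP-Adapter` is the public IP-Adapter weight repository):

```py
import torch

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# single adapter: plain strings
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# multiple adapters: subfolder and weight_name as equal-length lists
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder=["models", "models"],
    weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"],
)
```
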
@@ -129,7 +135,7 @@ class IPAdapterMixin:
         # Load the main state dict first.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -182,7 +188,7 @@ class IPAdapterMixin:
                             elif key.startswith("ip_adapter."):
                                 state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
                 else:
-                    state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = load_state_dict(model_file)
             else:
                 state_dict = pretrained_model_name_or_path_or_dict
 
@@ -227,27 +233,69 @@ class IPAdapterMixin:
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
 
+        extra_loras = unet._load_ip_adapter_loras(state_dicts)
+        if extra_loras != {}:
+            if not USE_PEFT_BACKEND:
+                logger.warning("PEFT backend is required to load these weights.")
+            else:
+                # apply the IP Adapter Face ID LoRA weights
+                peft_config = getattr(unet, "peft_config", {})
+                for k, lora in extra_loras.items():
+                    if f"faceid_{k}" not in peft_config:
+                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
     def set_ip_adapter_scale(self, scale):
         """
-        Sets the conditioning scale between text and image.
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
 
         Example:
 
         ```py
-        pipeline.set_ip_adapter_scale(0.5)
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style block only
+        scale = {
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style+layout blocks
+        scale = {
+            "down": {"block_2": [0.0, 1.0]},
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style and layout from 2 reference images
+        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+        pipeline.set_ip_adapter_scale(scales)
         ```
         """
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        for attn_processor in unet.attn_processors.values():
+        if not isinstance(scale, list):
+            scale = [scale]
+        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+        for attn_name, attn_processor in unet.attn_processors.items():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
-                if not isinstance(scale, list):
-                    scale = [scale] * len(attn_processor.scale)
-                if len(attn_processor.scale) != len(scale):
+                if len(scale_configs) != len(attn_processor.scale):
                     raise ValueError(
-                        f"`scale` should be a list of same length as the number if ip-adapters "
-                        f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+                        f"Cannot assign {len(scale_configs)} scale_configs to "
+                        f"{len(attn_processor.scale)} IP-Adapter."
                     )
-                attn_processor.scale = scale
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    if isinstance(scale_config, dict):
+                        for k, s in scale_config.items():
+                            if attn_name.startswith(k):
+                                attn_processor.scale[i] = s
+                    else:
+                        attn_processor.scale[i] = scale_config
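
On the dict form: `_maybe_expand_lora_scales` turns a block-level spec into per-layer scales keyed by attention-processor name prefixes, with unlisted layers falling back to `default_scale=0.0`. A hedged sketch of what the expansion amounts to (layer names assume the standard SDXL UNet layout):

```py
# {"up": {"block_0": [0.0, 1.0, 0.0]}} expands to roughly:
#   {"up_blocks.0.attentions.0": 0.0,
#    "up_blocks.0.attentions.1": 1.0,
#    "up_blocks.0.attentions.2": 0.0}
# each IP-Adapter attention processor whose name starts with one of these keys
# receives the matching scale; unlisted layers get default_scale (0.0 here).
pipe.set_ip_adapter_scale({"up": {"block_0": [0.0, 1.0, 0.0]}})
```
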
 
     def unload_ip_adapter(self):
         """
@@ -278,4 +326,14 @@ class IPAdapterMixin:
         self.config.encoder_hid_dim_type = None
 
         # restore original Unet attention processors layers
-        self.unet.set_default_attn_processor()
+        attn_procs = {}
+        for name, value in self.unet.attn_processors.items():
+            attn_processor_class = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+            )
+            attn_procs[name] = (
+                attn_processor_class
+                if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+                else value.__class__()
+            )
+        self.unet.set_attn_processor(attn_procs)
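
A short round trip with the single-adapter pipeline from the earlier sketch: `unload_ip_adapter` now rebuilds only the IP-Adapter attention processors and re-instantiates other processors from their own classes, rather than forcing the whole UNet back to defaults via `set_default_attn_processor()`:

```py
import PIL.Image

ref_image = PIL.Image.open("reference.png")  # placeholder path: any RGB reference image

pipe.set_ip_adapter_scale(0.6)
image = pipe("a corgi astronaut", ip_adapter_image=ref_image).images[0]

# drops the image encoder, projection layers, and IP-Adapter attn processors
pipe.unload_ip_adapter()
```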