diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
--- a/diffusers/image_processor.py
+++ b/diffusers/image_processor.py
@@ -29,15 +29,62 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 PipelineImageInput = Union[
     PIL.Image.Image,
     np.ndarray,
-    torch.FloatTensor,
+    torch.Tensor,
     List[PIL.Image.Image],
     List[np.ndarray],
-    List[torch.FloatTensor],
+    List[torch.Tensor],
 ]
 
 PipelineDepthInput = PipelineImageInput
 
 
+def is_valid_image(image) -> bool:
+    r"""
+    Checks if the input is a valid image.
+
+    A valid image can be:
+    - A `PIL.Image.Image`.
+    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
+
+    Args:
+        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+            The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
+
+    Returns:
+        `bool`:
+            `True` if the input is a valid image, `False` otherwise.
+    """
+    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
+
+
+def is_valid_image_imagelist(images):
+    r"""
+    Checks if the input is a valid image or list of images.
+
+    The input can be one of the following formats:
+    - A 4D tensor or numpy array (batch of images).
+    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
+      `torch.Tensor`.
+    - A list of valid images.
+
+    Args:
+        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
+            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
+            images.
+
+    Returns:
+        `bool`:
+            `True` if the input is valid, `False` otherwise.
+    """
+    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+        return True
+    elif is_valid_image(images):
+        return True
+    elif isinstance(images, list):
+        return all(is_valid_image(image) for image in images)
+    return False
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
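
The two validators added above are module-level helpers in `diffusers/image_processor.py`. A minimal sketch of what they accept (the shapes below are illustrative):

    import numpy as np
    import torch
    from PIL import Image

    from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

    # Single images: a PIL image, or a 2D (grayscale) / 3D array or tensor.
    assert is_valid_image(Image.new("RGB", (64, 64)))
    assert is_valid_image(np.zeros((64, 64)))              # 2D grayscale
    assert is_valid_image(torch.zeros(3, 64, 64))          # 3D CHW
    assert not is_valid_image(torch.zeros(1, 3, 64, 64))   # 4D is a batch, not a single image

    # Batches: one 4D array/tensor, or a list of valid single images.
    assert is_valid_image_imagelist(torch.zeros(2, 3, 64, 64))
    assert is_valid_image_imagelist([Image.new("RGB", (64, 64))] * 2)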
@@ -67,6 +114,7 @@ class VaeImageProcessor(ConfigMixin):
         self,
         do_resize: bool = True,
         vae_scale_factor: int = 8,
+        vae_latent_channels: int = 4,
         resample: str = "lanczos",
         do_normalize: bool = True,
         do_binarize: bool = False,
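
The new `vae_latent_channels` config value feeds the latent passthrough check in `preprocess()` (see the hunk at old line 529 below), so pipelines whose VAE does not use 4 latent channels can still pass latents through unchanged. A hypothetical construction for a 16-channel-latent VAE:

    from diffusers.image_processor import VaeImageProcessor

    # Default matches the classic 4-latent-channel VAE.
    processor = VaeImageProcessor(vae_scale_factor=8)

    # Hypothetical 16-channel setup: preprocess() now treats 16-channel
    # inputs as latents and returns them as-is.
    processor_16 = VaeImageProcessor(vae_scale_factor=8, vae_latent_channels=16)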
@@ -80,12 +128,19 @@ class VaeImageProcessor(ConfigMixin):
                 " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
                 " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
             )
-            self.config.do_convert_rgb = False
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
+        r"""
         Convert a numpy image or a batch of images to a PIL image.
+
+        Args:
+            images (`np.ndarray`):
+                The image array to convert to PIL format.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -100,8 +155,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+                The PIL image or list of images to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -111,9 +174,17 @@ class VaeImageProcessor(ConfigMixin):
         return images
 
     @staticmethod
-    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
-        """
+    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
+        r"""
         Convert a NumPy image to a PyTorch tensor.
+
+        Args:
+            images (`np.ndarray`):
+                The NumPy image array to convert to PyTorch format.
+
+        Returns:
+            `torch.Tensor`:
+                A PyTorch tensor representation of the images.
         """
         if images.ndim == 3:
             images = images[..., None]
@@ -122,31 +193,63 @@ class VaeImageProcessor(ConfigMixin):
         return images
 
     @staticmethod
-    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
-        """
+    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
+        r"""
         Convert a PyTorch tensor to a NumPy image.
+
+        Args:
+            images (`torch.Tensor`):
+                The PyTorch tensor to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
 
     @staticmethod
     def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Normalize an image array to [-1,1].
+
+        Args:
+            images (`np.ndarray` or `torch.Tensor`):
+                The image array to normalize.
+
+        Returns:
+            `np.ndarray` or `torch.Tensor`:
+                The normalized image array.
         """
         return 2.0 * images - 1.0
 
     @staticmethod
     def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Denormalize an image array to [0,1].
+
+        Args:
+            images (`np.ndarray` or `torch.Tensor`):
+                The image array to denormalize.
+
+        Returns:
+            `np.ndarray` or `torch.Tensor`:
+                The denormalized image array.
         """
-        return (images / 2 + 0.5).clamp(0, 1)
+        return (images * 0.5 + 0.5).clamp(0, 1)
 
     @staticmethod
     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
+        r"""
         Converts a PIL image to RGB format.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The PIL image to convert to RGB.
+
+        Returns:
+            `PIL.Image.Image`:
+                The RGB-converted PIL image.
         """
         image = image.convert("RGB")
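
The `denormalize` change above is purely cosmetic: `images / 2 + 0.5` and `images * 0.5 + 0.5` compute the same inverse of `normalize`. A quick illustrative round-trip:

    import torch

    x = torch.rand(1, 3, 8, 8)                        # image in [0, 1]
    normalized = 2.0 * x - 1.0                        # normalize: [0, 1] -> [-1, 1]
    restored = (normalized * 0.5 + 0.5).clamp(0, 1)   # denormalize: [-1, 1] -> [0, 1]
    assert torch.allclose(x, restored)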
 
@@ -154,8 +257,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        Converts a PIL image to grayscale format.
+        r"""
+        Converts a given PIL image to grayscale.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The input image to convert.
+
+        Returns:
+            `PIL.Image.Image`:
+                The image converted to grayscale.
         """
         image = image.convert("L")
 
@@ -163,8 +274,16 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
-        """
+        r"""
         Applies Gaussian blur to an image.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The PIL image to convert to grayscale.
+
+        Returns:
+            `PIL.Image.Image`:
+                The grayscale-converted PIL image.
         """
         image = image.filter(ImageFilter.GaussianBlur(blur_factor))
 
@@ -172,9 +291,10 @@ class VaeImageProcessor(ConfigMixin):
 
     @staticmethod
     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
-        """
-        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
-        for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
+        r"""
+        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
+        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
+        processing are 512x512, the region will be expanded to 128x128.
 
         Args:
             mask_image (PIL.Image.Image): Mask image.
@@ -183,7 +303,8 @@ class VaeImageProcessor(ConfigMixin):
             pad (int, optional): Padding to be added to the crop region. Defaults to 0.
 
         Returns:
-            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
+            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
+            matches the original aspect ratio.
         """
 
         mask_image = mask_image.convert("L")
@@ -264,13 +385,21 @@ class VaeImageProcessor(ConfigMixin):
         width: int,
         height: int,
     ) -> PIL.Image.Image:
-        """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        r"""
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, filling empty with data from image.
 
         Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
+            image (`PIL.Image.Image`):
+                The image to resize and fill.
+            width (`int`):
+                The width to resize the image to.
+            height (`int`):
+                The height to resize the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The resized and filled image.
         """
 
         ratio = width / height
@@ -308,13 +437,21 @@ class VaeImageProcessor(ConfigMixin):
         width: int,
         height: int,
     ) -> PIL.Image.Image:
-        """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+        r"""
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, cropping the excess.
 
         Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
+            image (`PIL.Image.Image`):
+                The image to resize and crop.
+            width (`int`):
+                The width to resize the image to.
+            height (`int`):
+                The height to resize the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The resized and cropped image.
         """
         ratio = width / height
         src_ratio = image.width / image.height
@@ -346,12 +483,12 @@ class VaeImageProcessor(ConfigMixin):
                 The width to resize to.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
+                will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
+                then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
+                the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+                the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
 
         Returns:
             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
@@ -400,25 +537,49 @@ class VaeImageProcessor(ConfigMixin):
 
         return image
 
+    def _denormalize_conditionally(
+        self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
+    ) -> torch.Tensor:
+        r"""
+        Denormalize a batch of images based on a condition list.
+
+        Args:
+            images (`torch.Tensor`):
+                The input image tensor.
+            do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
+                A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
+                value of `do_normalize` in the `VaeImageProcessor` config.
+        """
+        if do_denormalize is None:
+            return self.denormalize(images) if self.config.do_normalize else images
+
+        return torch.stack(
+            [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
+        )
+
     def get_default_height_width(
         self,
         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
     ) -> Tuple[int, int]:
-        """
-        This function return the height and width that are downscaled to the next integer multiple of
-        `vae_scale_factor`.
+        r"""
+        Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
 
         Args:
-            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
-                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
-                have shape `[batch, channel, height, width]`.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the height of `image` input.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use the width of the `image` input.
+            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+                The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+                should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+                tensor, it should have shape `[batch, channels, height, width]`.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+
+        Returns:
+            `Tuple[int, int]`:
+                A tuple containing the height and width, both resized to the nearest integer multiple of
+                `vae_scale_factor`.
         """
 
         if height is None:
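
The rounding described in the `get_default_height_width` docstring is a floor to the nearest multiple of `vae_scale_factor`; the method body is not part of this hunk, but the arithmetic amounts to the illustrative check below:

    vae_scale_factor = 8

    # Both forms floor to the nearest multiple of 8; e.g. a 515x771 input
    # would be processed at 512x768.
    for size in (512, 515, 771, 7):
        assert size - size % vae_scale_factor == (size // vae_scale_factor) * vae_scale_factor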
@@ -455,22 +616,28 @@ class VaeImageProcessor(ConfigMixin):
         Preprocess the image input.
 
         Args:
-            image (`pipeline_image_input`):
-                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
+            image (`PipelineImageInput`):
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
+                supported formats.
+            height (`int`, *optional*):
+                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
+                height.
+            width (`int`, *optional*):
+                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
             resize_mode (`str`, *optional*, defaults to `default`):
-                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
+                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
+                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
+                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
+                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
+
+        Returns:
+            `torch.Tensor`:
+                The preprocessed image.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
 
@@ -492,12 +659,27 @@ class VaeImageProcessor(ConfigMixin):
                 else:
                     image = np.expand_dims(image, axis=-1)
 
-        if isinstance(image, supported_formats):
-            image = [image]
-        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
+        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
+                FutureWarning,
+            )
+            image = np.concatenate(image, axis=0)
+        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
+                FutureWarning,
+            )
+            image = torch.cat(image, axis=0)
+
+        if not is_valid_image_imagelist(image):
             raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
+                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
             )
+        if not isinstance(image, list):
+            image = [image]
 
         if isinstance(image[0], PIL.Image.Image):
             if crops_coords is not None:
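
The deprecation path above affects callers that pass a list of 4D arrays or tensors; a sketch of the migration the warning asks for (shapes are illustrative):

    import torch

    batches = [torch.rand(1, 3, 512, 512), torch.rand(1, 3, 512, 512)]

    # Deprecated: passing `batches` directly now emits a FutureWarning.
    # Preferred: concatenate along the batch dimension first.
    image = torch.cat(batches, dim=0)   # shape (2, 3, 512, 512)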
@@ -529,7 +711,7 @@ class VaeImageProcessor(ConfigMixin):
 
             channel = image.shape[1]
             # don't need any preprocess if the image is latents
-            if channel == 4:
+            if channel == self.config.vae_latent_channels:
                 return image
 
             height, width = self.get_default_height_width(image, height, width)
@@ -545,7 +727,6 @@ class VaeImageProcessor(ConfigMixin):
                 FutureWarning,
             )
             do_normalize = False
-
         if do_normalize:
             image = self.normalize(image)
 
@@ -556,15 +737,15 @@ class VaeImageProcessor(ConfigMixin):
 
     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.
 
         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -573,7 +754,7 @@ class VaeImageProcessor(ConfigMixin):
                 `VaeImageProcessor` config.
 
         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -591,12 +772,7 @@ class VaeImageProcessor(ConfigMixin):
         if output_type == "latent":
             return image
 
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-        )
+        image = self._denormalize_conditionally(image, do_denormalize)
 
         if output_type == "pt":
             return image
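
`postprocess` keeps its public behavior; the per-image logic is simply factored into the `_denormalize_conditionally` helper added earlier. An illustrative call with a mixed batch, where `processor` is the instance from the sketch further up and `decoded` is a hypothetical `(2, 3, H, W)` tensor in `[-1, 1]`:

    images = processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True, False],  # denormalize only the first image
    )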
@@ -616,17 +792,29 @@ class VaeImageProcessor(ConfigMixin):
         image: PIL.Image.Image,
         crop_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> PIL.Image.Image:
-        """
-        overlay the inpaint output to the original image
-        """
+        r"""
+        Applies an overlay of the mask and the inpainted image on the original image.
 
-        width, height = image.width, image.height
+        Args:
+            mask (`PIL.Image.Image`):
+                The mask image that highlights regions to overlay.
+            init_image (`PIL.Image.Image`):
+                The original image to which the overlay is applied.
+            image (`PIL.Image.Image`):
+                The image to overlay onto the original.
+            crop_coords (`Tuple[int, int, int, int]`, *optional*):
+                Coordinates to crop the image. If provided, the image will be cropped accordingly.
+
+        Returns:
+            `PIL.Image.Image`:
+                The final image with the overlay applied.
+        """
 
-        init_image = self.resize(init_image, width=width, height=height)
-        mask = self.resize(mask, width=width, height=height)
+        width, height = init_image.width, init_image.height
 
         init_image_masked = PIL.Image.new("RGBa", (width, height))
         init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
+
         init_image_masked = init_image_masked.convert("RGBA")
 
         if crop_coords is not None:
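
`apply_overlay` now takes its canvas size from `init_image` instead of resizing `init_image` and `mask` to the output image. An illustrative inpainting call (file names are hypothetical; `processor` is the instance from the earlier sketch):

    from PIL import Image

    init_image = Image.open("original.png").convert("RGB")
    mask = Image.open("mask.png").convert("L")
    inpainted = Image.open("inpainted.png").convert("RGB")

    # Pastes the unmasked region of init_image over the inpainted result.
    result = processor.apply_overlay(mask, init_image, inpainted)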
@@ -674,8 +862,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy array.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -690,8 +886,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     @staticmethod
     def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
+                The input image or list of images to be converted.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array of the converted images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -702,18 +906,30 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     @staticmethod
     def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
-        Args:
-            image: RGB-like depth image
+        r"""
+        Convert an RGB-like depth image to a depth map.
 
-        Returns: depth map
+        Args:
+            image (`Union[np.ndarray, torch.Tensor]`):
+                The RGB-like depth image to convert.
 
+        Returns:
+            `Union[np.ndarray, torch.Tensor]`:
+                The corresponding depth map.
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]
 
     def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy depth image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy depth image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of depth images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy depth images.
         """
         if images.ndim == 3:
             images = images[None, ...]
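
`rgblike_to_depthmap` decodes a 16-bit depth value packed into the G (high byte) and B (low byte) channels. A small worked example:

    import numpy as np

    from diffusers.image_processor import VaeImageProcessorLDM3D

    # One pixel with G=2, B=5 encodes depth 2 * 2**8 + 5 = 517.
    rgb_like = np.array([[[0, 2, 5]]], dtype=np.int64)  # shape (1, 1, 3)
    depth = VaeImageProcessorLDM3D.rgblike_to_depthmap(rgb_like)
    assert depth[0, 0] == 517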
@@ -733,15 +949,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
733
949
 
734
950
  def postprocess(
735
951
  self,
736
- image: torch.FloatTensor,
952
+ image: torch.Tensor,
737
953
  output_type: str = "pil",
738
954
  do_denormalize: Optional[List[bool]] = None,
739
- ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
955
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
740
956
  """
741
957
  Postprocess the image output from tensor to `output_type`.
742
958
 
743
959
  Args:
744
- image (`torch.FloatTensor`):
960
+ image (`torch.Tensor`):
745
961
  The image input, should be a pytorch tensor with shape `B x C x H x W`.
746
962
  output_type (`str`, *optional*, defaults to `pil`):
747
963
  The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -750,7 +966,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
                 `VaeImageProcessor` config.
 
         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -765,12 +981,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
             deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
             output_type = "np"
 
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-        )
+        image = self._denormalize_conditionally(image, do_denormalize)
 
         image = self.pt_to_numpy(image)
 
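The deleted per-image loop moves into a `_denormalize_conditionally` helper on the base class. Judging from the removed lines, the helper is roughly equivalent to the sketch below; apparently the only behavioral addition is a fast path that denormalizes the whole batch at once when no per-image mask is given:

```python
import torch
from typing import List, Optional

def _denormalize_conditionally(
    self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
) -> torch.Tensor:
    # Fast path: no per-image flags, so follow the processor-wide config.
    if do_denormalize is None:
        return self.denormalize(images) if self.config.do_normalize else images
    # Otherwise denormalize image-by-image, exactly like the deleted loop.
    return torch.stack(
        [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
    )
```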
@@ -788,14 +999,30 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
 
     def preprocess(
         self,
-        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
+        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
         height: Optional[int] = None,
         width: Optional[int] = None,
         target_res: Optional[int] = None,
     ) -> torch.Tensor:
-        """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        r"""
+        Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
+
+        Args:
+            rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The RGB input image, which can be a single image or a batch.
+            depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The depth input image, which can be a single image or a batch.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The desired height of the processed image. If `None`, defaults to the height of the input image.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The desired width of the processed image. If `None`, defaults to the width of the input image.
+            target_res (`Optional[int]`, *optional*, defaults to `None`):
+                Target resolution for resizing the images. If specified, overrides height and width.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing the processed RGB and depth images as PyTorch tensors.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
 
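Note that the new docstring documents a `Tuple[torch.Tensor, torch.Tensor]` return while the annotation still reads `-> torch.Tensor`; the tuple is what the method actually yields. A minimal usage sketch, assuming 3-channel NumPy inputs in `[0, 1]` for both modalities (with the default `do_normalize=True`, the outputs land in `[-1, 1]`):

```python
import numpy as np
from diffusers.image_processor import VaeImageProcessorLDM3D

processor = VaeImageProcessorLDM3D(vae_scale_factor=8)

# HxWxC float arrays in [0, 1]; 64 is already a multiple of vae_scale_factor.
rgb = np.random.rand(64, 64, 3).astype(np.float32)
depth = np.random.rand(64, 64, 3).astype(np.float32)

rgb_t, depth_t = processor.preprocess(rgb, depth)
print(rgb_t.shape, depth_t.shape)  # torch.Size([1, 3, 64, 64]) for both
```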
@@ -928,13 +1155,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
         )
 
     @staticmethod
-    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
+    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
         """
-        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
-        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
+        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
 
         Args:
-            mask (`torch.FloatTensor`):
+            mask (`torch.Tensor`):
                 The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
             batch_size (`int`):
                 The batch size.
@@ -944,7 +1171,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
                 The dimensionality of the value embeddings.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The downsampled mask tensor.
 
         """
@@ -988,3 +1215,100 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
         )
 
         return mask_downsample
+
+
+class PixArtImageProcessor(VaeImageProcessor):
+    """
+    Image processor for PixArt image resize and crop.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to grayscale format.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+        r"""
+        Returns the binned height and width based on the aspect ratio.
+
+        Args:
+            height (`int`): The height of the image.
+            width (`int`): The width of the image.
+            ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
+
+        Returns:
+            `Tuple[int, int]`: The closest binned height and width.
+        """
+        ar = float(height / width)
+        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+        default_hw = ratios[closest_ratio]
+        return int(default_hw[0]), int(default_hw[1])
+
+    @staticmethod
+    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        r"""
+        Resizes and crops a tensor of images to the specified dimensions.
+
+        Args:
+            samples (`torch.Tensor`):
+                A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
+                and W is the width.
+            new_width (`int`): The desired width of the output images.
+            new_height (`int`): The desired height of the output images.
+
+        Returns:
+            `torch.Tensor`: A tensor containing the resized and cropped images.
+        """
+        orig_height, orig_width = samples.shape[2], samples.shape[3]
+
+        # Check if resizing is needed
+        if orig_height != new_height or orig_width != new_width:
+            ratio = max(new_height / orig_height, new_width / orig_width)
+            resized_width = int(orig_width * ratio)
+            resized_height = int(orig_height * ratio)
+
+            # Resize
+            samples = F.interpolate(
+                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+            )
+
+            # Center Crop
+            start_x = (resized_width - new_width) // 2
+            end_x = start_x + new_width
+            start_y = (resized_height - new_height) // 2
+            end_y = start_y + new_height
+            samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+        return samples
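The new `PixArtImageProcessor` is self-contained enough to exercise directly. A small sketch of the two static helpers together; the `ratios` table here is an illustrative stand-in for the much larger bin tables the PixArt pipelines ship (e.g. `ASPECT_RATIO_1024_BIN`):

```python
import torch
from diffusers.image_processor import PixArtImageProcessor

# Keys are aspect ratios (height / width) as strings, values are (H, W) bins.
ratios = {"0.5": (512.0, 1024.0), "1.0": (1024.0, 1024.0), "2.0": (1024.0, 512.0)}

# 720x800 has aspect ratio 0.9, so it lands in the square bin.
height, width = PixArtImageProcessor.classify_height_width_bin(720, 800, ratios)
print(height, width)  # 1024 1024

# Aspect-fill resize (scale factor 2.0 here: 512x768 -> 1024x1536),
# then center crop down to the binned 1024x1024.
samples = torch.rand(1, 3, 512, 768)
out = PixArtImageProcessor.resize_and_crop_tensor(samples, new_width=width, new_height=height)
print(out.shape)  # torch.Size([1, 3, 1024, 1024])
```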