diffusers-0.32.1-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
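
Before the per-file hunks, note the shape of the release: a new `diffusers/hooks` package (group offloading, layerwise casting, pyramid attention broadcast, FasterCache), `models/auto_model.py`, a Quanto quantizer backend, and new pipeline families (Wan, CogView4, Lumina2, OmniGen, EasyAnimate, ConsisID, SANA Sprint, LTX condition). As orientation, here is a minimal sketch of the two memory-oriented hooks as exposed on models in 0.33.0 — the argument names are a best-effort reconstruction of the new API, not copied from this diff:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Group offloading (hooks/group_offloading.py): park groups of layers on CPU
# and stream them onto the GPU only while they run.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

# Layerwise casting (hooks/layerwise_casting.py): store weights in fp8 and
# upcast to the compute dtype one layer at a time.
pipe.vae.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)

image = pipe("a photo of a corgi", num_inference_steps=28).images[0]
```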
diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -13,19 +13,21 @@
 # limitations under the License.
 
 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from transformers import (
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    SiglipImageProcessor,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -80,7 +82,7 @@ def calculate_shift(
     base_seq_len: int = 256,
     max_seq_len: int = 4096,
     base_shift: float = 0.5,
-    max_shift: float = 1.16,
+    max_shift: float = 1.15,
 ):
     m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
     b = base_shift - m * base_seq_len
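
The hunk above only shows the changed default (`max_shift` 1.16 → 1.15). For context, a sketch of the full helper as it reads in 0.33.0; the last two lines, which compute and return `mu`, sit just outside the hunk:

```python
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    # Linearly interpolate the flow-matching timestep shift ("mu") between
    # base_shift and max_shift as the latent sequence length grows.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
```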
@@ -162,7 +164,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`SiglipVisionModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`SiglipImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -211,6 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: Optional[SiglipVisionModel] = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()
 
@@ -224,20 +232,29 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         self.image_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
         )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor,
-            vae_latent_channels=self.vae.config.latent_channels,
+            vae_latent_channels=latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
         )
-        self.tokenizer_max_length = self.tokenizer.model_max_length
-        self.default_sample_size = self.transformer.config.sample_size
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
         self.patch_size = (
             self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
         )
@@ -399,9 +416,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
-                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
@@ -811,6 +828,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1
 
+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -819,6 +840,84 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device: (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
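
With `SD3IPAdapterMixin` in the bases and the `encode_image`/`prepare_ip_adapter_image_embeds` plumbing above, the inpaint pipeline can now take an image prompt. A hedged usage sketch; the checkpoint id and URLs are illustrative stand-ins, not taken from this diff:

```python
import torch
from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # via the new SD3IPAdapterMixin
pipe.set_ip_adapter_scale(0.6)

image = pipe(
    prompt="a fluffy cat sitting on a park bench",
    image=load_image("https://example.com/input.png"),            # image to inpaint
    mask_image=load_image("https://example.com/mask.png"),        # white = repaint
    ip_adapter_image=load_image("https://example.com/style.png"), # image prompt
).images[0]
```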
@@ -846,8 +945,11 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -883,9 +985,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
            mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                latents tensor will ge generated by `mask_image`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            padding_mask_crop (`int`, *optional*, defaults to `None`):
                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -946,12 +1048,22 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -999,6 +1111,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
         # 2. Define call parameters
@@ -1046,10 +1159,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             )
             mu = calculate_shift(
                 image_seq_len,
-                self.scheduler.config.base_image_seq_len,
-                self.scheduler.config.max_image_seq_len,
-                self.scheduler.config.base_shift,
-                self.scheduler.config.max_shift,
+                self.scheduler.config.get("base_image_seq_len", 256),
+                self.scheduler.config.get("max_image_seq_len", 4096),
+                self.scheduler.config.get("base_shift", 0.5),
+                self.scheduler.config.get("max_shift", 1.16),
             )
             scheduler_kwargs["mu"] = mu
         elif mu is not None:
@@ -1145,7 +1258,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                    f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
                    f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                    " `pipeline.transformer` or your `mask_image` or `image` input."
                )
        elif num_channels_transformer != 16:
@@ -1153,7 +1266,22 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
            )
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1174,6 +1302,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
 
diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,14 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 logger = logging.get_logger(__name__)
 
 EXAMPLE_DOC_STRING = """
@@ -194,8 +203,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
-            about a model's potential harms.
+            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+            more details about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """
@@ -242,7 +251,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -1008,6 +1017,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
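
The `is_torch_xla_available()` guard plus one `xm.mark_step()` per denoising step is the pattern this release rolls out across dozens of pipelines. On XLA backends, operations are recorded lazily; `mark_step()` cuts the graph so each step compiles and executes instead of the whole loop piling into a single trace. A standalone illustration (assumes a TPU/XLA environment with `torch_xla` installed):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # e.g. one TPU core
x = torch.zeros(1024, 1024, device=device)

for _ in range(10):
    x = torch.sin(x) + 1.0  # recorded lazily; nothing has run yet
    xm.mark_step()          # materialize this step: compile + execute the graph
```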
diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -33,6 +33,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -44,6 +45,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -268,8 +276,8 @@ class StableDiffusionDiffEditPipeline(
            A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
-            about a model's potential harms.
+            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+            more details about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """
@@ -292,7 +300,7 @@ class StableDiffusionDiffEditPipeline(
     ):
         super().__init__()
 
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+        if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -306,7 +314,7 @@ class StableDiffusionDiffEditPipeline(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+        if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} has not set the configuration"
                 " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -336,17 +344,21 @@ class StableDiffusionDiffEditPipeline(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        is_unet_version_less_0_9_0 = (
+            unet is not None
+            and hasattr(unet.config, "_diffusers_version")
+            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+        )
+        is_unet_sample_size_less_64 = (
+            unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        )
         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
             deprecation_message = (
                 "The configuration file of the unet has set the default `sample_size` to smaller than"
                 " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                 " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+                " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                 " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                 " in the config might lead to incorrect results in future versions. If you have downloaded this"
                 " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -367,7 +379,7 @@ class StableDiffusionDiffEditPipeline(
             feature_extractor=feature_extractor,
             inverse_scheduler=inverse_scheduler,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -1508,6 +1520,9 @@ class StableDiffusionDiffEditPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -29,6 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,8 +41,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
 Examples:
     ```py
@@ -120,8 +129,8 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
-            about a model's potential harms.
+            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+            more details about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """
@@ -168,7 +177,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -828,6 +837,9 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
@@ -32,7 +32,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention import GatedSelfAttentionDense
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -40,8 +47,16 @@ from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
 Examples:
     ```py
@@ -172,8 +187,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
-            about a model's potential harms.
+            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+            more details about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """
@@ -226,7 +241,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -1010,6 +1025,9 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -19,15 +19,31 @@ from typing import Callable, List, Optional, Union
 import torch
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
 from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPTokenizerFast,
+)
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import (
+    StableDiffusionLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
-from ...schedulers import LMSDiscreteScheduler
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -83,7 +99,8 @@ class StableDiffusionKDiffusionPipeline(
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
-            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+            details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
@@ -94,13 +111,13 @@ class StableDiffusionKDiffusionPipeline(
 
     def __init__(
         self,
-        vae,
-        text_encoder,
-        tokenizer,
-        unet,
-        scheduler,
-        safety_checker,
-        feature_extractor,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -124,7 +141,7 @@ class StableDiffusionKDiffusionPipeline(
             feature_extractor=feature_extractor,
         )
         self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
         model = ModelWrapper(unet, scheduler.alphas_cumprod)
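
The annotations above only type what the pipeline already required; usage is unchanged. A brief sketch, using the renamed `stable-diffusion-v1-5` repo referenced throughout this diff (`set_scheduler` picks a sampler from `k_diffusion.sampling`):

```python
import torch
from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.set_scheduler("sample_dpmpp_2m")  # any k-diffusion sampler name

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```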