diffusers-0.32.1-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
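A comparison like this can be reproduced locally from the two wheels. The sketch below uses only the Python standard library; the wheel filenames and the chosen member path are illustrative assumptions (any file from the list below works), not values taken from the registry data.

```py
# Minimal sketch: diff one file between two locally downloaded wheels,
# e.g. fetched beforehand with `pip download diffusers==0.32.1 --no-deps`
# and `pip download diffusers==0.33.0 --no-deps`.
import difflib
import zipfile

OLD_WHEEL = "diffusers-0.32.1-py3-none-any.whl"  # assumed local filename
NEW_WHEEL = "diffusers-0.33.0-py3-none-any.whl"  # assumed local filename
MEMBER = "diffusers/utils/constants.py"  # any file present in both wheels


def read_member(wheel_path, member):
    """Return the decoded lines of one file inside a wheel (a zip archive)."""
    with zipfile.ZipFile(wheel_path) as wheel:
        return wheel.read(member).decode("utf-8").splitlines(keepends=True)


print(
    "".join(
        difflib.unified_diff(
            read_member(OLD_WHEEL, MEMBER),
            read_member(NEW_WHEEL, MEMBER),
            fromfile=f"0.32.1/{MEMBER}",
            tofile=f"0.33.0/{MEMBER}",
        )
    )
)
```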
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
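The largest single addition in this list is the new SANA-Sprint text-to-image pipeline (diffusers/pipelines/sana/pipeline_sana_sprint.py, +889 lines), whose diff follows. A minimal usage sketch, adapted from the EXAMPLE_DOC_STRING in that new file; the checkpoint id comes from the same docstring, and the pipeline's `__call__` defaults to `num_inference_steps=2` as shown in the diffed signature:

```py
import torch
from diffusers import SanaSprintPipeline

# Checkpoint id taken from the example docstring in the new file below.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# num_inference_steps defaults to 2 in the new __call__ signature, so this is a two-step generation.
image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
image[0].save("output.png")
```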
diffusers/pipelines/sana/pipeline_sana_sprint.py (new file)
@@ -0,0 +1,889 @@
+ # Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import html
+ import inspect
+ import re
+ import urllib.parse as ul
+ import warnings
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+ from ...image_processor import PixArtImageProcessor
+ from ...loaders import SanaLoraLoaderMixin
+ from ...models import AutoencoderDC, SanaTransformer2DModel
+ from ...schedulers import DPMSolverMultistepScheduler
+ from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+ )
+ from ...utils.torch_utils import randn_tensor
+ from ..pipeline_utils import DiffusionPipeline
+ from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
+ from .pipeline_output import SanaPipelineOutput
+
+
+ if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+ else:
+ XLA_AVAILABLE = False
+
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+ if is_ftfy_available():
+ import ftfy
+
+
+ EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaSprintPipeline
+
+ >>> pipe = SanaSprintPipeline.from_pretrained(
+ ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
+ >>> image[0].save("output.png")
+ ```
+ """
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+ ):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: AutoencoderDC,
+ transformer: SanaTransformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None
+ else 32
+ )
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.transformer is not None:
+ dtype = self.transformer.dtype
+ elif self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ max_timesteps,
+ intermediate_timesteps,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if timesteps is not None and len(timesteps) != num_inference_steps + 1:
+ raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
+
+ if timesteps is not None and max_timesteps is not None:
+ raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
+
+ if timesteps is None and max_timesteps is None:
+ raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
+
+ if intermediate_timesteps is not None and num_inference_steps != 2:
+ raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("<person>", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @<nickname>
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # &quot;
+ caption = re.sub(r"&quot;?", "", caption)
+ # &amp
+ caption = re.sub(r"&amp", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 2,
+ timesteps: List[int] = None,
+ max_timesteps: float = 1.57080,
+ intermediate_timesteps: float = 1.3,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: int = 1024,
+ width: int = 1024,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaPipelineOutput, Tuple]:
+ """
630
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ max_timesteps (`float`, *optional*, defaults to 1.57080):
+ The maximum timestep value used in the SCM scheduler.
+ intermediate_timesteps (`float`, *optional*, defaults to 1.3):
+ The intermediate timestep value used in the SCM scheduler (only used when `num_inference_steps=2`).
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from the `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] instead of a plain
+ tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool`, defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ max_timesteps=max_timesteps,
+ intermediate_timesteps=intermediate_timesteps,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=None,
+ max_timesteps=max_timesteps,
+ intermediate_timesteps=intermediate_timesteps,
+ )
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(0)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ latents = latents * self.scheduler.config.sigma_data
+
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ guidance = guidance * self.transformer.config.guidance_embeds_scale
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ timesteps = timesteps[:-1]
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ latents_model_input = latents / self.scheduler.config.sigma_data
+
+ scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
+
+ scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
+ latent_model_input = latents_model_input * torch.sqrt(
+ scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
+ )
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=guidance,
+ timestep=scm_timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+
+ noise_pred = (
+ (1 - 2 * scm_timestep_expanded) * latent_model_input
+ + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
+ ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
+ noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
+
+ # compute previous image: x_t -> x_t-1
+ latents, denoised = self.scheduler.step(
+ noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
+ )
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = denoised / self.scheduler.config.sigma_data
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ try:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ except torch.cuda.OutOfMemoryError as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return SanaPipelineOutput(images=image)
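
The denoising loop above re-parameterizes each scheduler timestep as `sin(t) / (cos(t) + sin(t))` and rescales the model input by `sqrt(scm_t**2 + (1 - scm_t)**2)` before calling the transformer. The stand-alone snippet below reproduces just that arithmetic (no pipeline, scheduler, or transformer objects; all names are local to the snippet) so the endpoint behaviour is easy to see against the documented `max_timesteps` default of 1.57080.

```python
# Sketch of the SCM timestep reparameterization and input scaling used in the loop above.
import torch

for t in torch.tensor([1e-4, 0.5, 1.0, 1.3, 1.5708]):
    # Same formula as `scm_timestep` in the loop.
    scm_t = torch.sin(t) / (torch.cos(t) + torch.sin(t))
    # Same factor that multiplies `latents_model_input` before the transformer call.
    input_scale = torch.sqrt(scm_t**2 + (1 - scm_t) ** 2)
    print(f"t={t.item():.4f}  scm_t={scm_t.item():.4f}  scale={input_scale.item():.4f}")
```

As `t` approaches 0 the reparameterized timestep goes to 0, and at `t ≈ π/2` (the `max_timesteps` default) it reaches 1, while the input scale stays between `sqrt(0.5)` and 1, which keeps the rescaled latent input bounded across the whole sampling range.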