diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
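Most of the churn above is additive: new hook modules (group offloading, layerwise casting, pyramid attention broadcast, FasterCache), new pipelines (CogView4, ConsisID, EasyAnimate, Lumina2, OmniGen, SANA-Sprint, Wan), and a Quanto quantizer backend. As a taste of the hooks work, here is a hedged usage sketch of layerwise casting; the method name and checkpoint id are assumptions based on the 0.33 release docs, not something this diff states.

```python
import torch
from diffusers import SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed repo id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
# Store weights in fp8 but upcast each layer to bf16 for its forward pass.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```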
diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

```diff
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
```
```diff
@@ -36,8 +37,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
```
```diff
@@ -285,7 +294,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
         # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
```
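The new `getattr` guard lets the pipeline be constructed (and its image processor sized) even when no VAE is registered, falling back to the factor a standard 4-stage VAE would produce. A quick check of that arithmetic, with a typical config as the assumption:

```python
# A typical SD/PixArt VAE config has four block_out_channels entries
# (an assumption for illustration), i.e. three 2x downsampling stages.
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8  # matches the hard-coded fallback in the diff
```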
```diff
@@ -898,10 +907,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                 # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                 # This would be a good case for the `match` statement (Python 3.10+)
                 is_mps = latent_model_input.device.type == "mps"
+                is_npu = latent_model_input.device.type == "npu"
                 if isinstance(current_timestep, float):
-                    dtype = torch.float32 if is_mps else torch.float64
+                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                 else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
             elif len(current_timestep.shape) == 0:
                 current_timestep = current_timestep[None].to(latent_model_input.device)
```
```diff
@@ -931,8 +941,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
 
                 # compute previous image: x_t -> x_t-1
                 if num_inference_steps == 1:
-                    # For DMD one step sampling: https://arxiv.org/abs/2311.18828
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
                 else:
                     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
```
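The one-step DMD branch (https://arxiv.org/abs/2311.18828) still keeps the scheduler's `pred_original_sample`; it now just reads it from the tuple output (`return_dict=False`, index 1) instead of the output dataclass. A minimal sketch of the equivalence, using `DDPMScheduler` as a stand-in:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(10)
sample = torch.randn(1, 4, 8, 8)
noise_pred = torch.randn_like(sample)
t = scheduler.timesteps[0]

out = scheduler.step(noise_pred, t, sample)  # dataclass output
x0_hat = out.pred_original_sample  # what `.pred_original_sample` used to read
# Per this diff, `step(..., return_dict=False)[1]` is assumed to carry the
# same tensor for schedulers that expose it in their tuple output.
```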
```diff
@@ -943,6 +952,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
```
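The `xm.mark_step()` call at the end of each denoising iteration is what makes the pipeline usable on TPUs: torch_xla builds graphs lazily, and without an explicit step marker the whole sampling loop would be traced into one oversized graph. A minimal sketch of the pattern, using a plain try/except where the diff uses diffusers' `is_torch_xla_available()` helper:

```python
# Optional torch_xla import, mirroring the guard added in this diff.
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def denoising_loop(num_steps: int) -> None:
    for _ in range(num_steps):
        ...  # run one denoising step (model forward + scheduler.step)
        if XLA_AVAILABLE:
            # Flush the lazily accumulated XLA graph so each iteration
            # compiles and executes instead of growing one huge trace.
            xm.mark_step()
```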
diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py

```diff
@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
```
```diff
@@ -41,8 +42,16 @@ from .pipeline_pixart_alpha import (
 )
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
```
```diff
@@ -211,7 +220,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
         # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
```
```diff
@@ -813,10 +822,11 @@ class PixArtSigmaPipeline(DiffusionPipeline):
                 # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                 # This would be a good case for the `match` statement (Python 3.10+)
                 is_mps = latent_model_input.device.type == "mps"
+                is_npu = latent_model_input.device.type == "npu"
                 if isinstance(current_timestep, float):
-                    dtype = torch.float32 if is_mps else torch.float64
+                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                 else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
             elif len(current_timestep.shape) == 0:
                 current_timestep = current_timestep[None].to(latent_model_input.device)
```
```diff
@@ -854,8 +864,11 @@ class PixArtSigmaPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
                 image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
         else:
```
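Besides the shared XLA and NPU changes, the Sigma pipeline now casts latents to the VAE's dtype before decoding. That guards against the common mismatch where sampling runs in fp16/bf16 while the VAE was loaded (or upcast) in another precision; a toy illustration of the failure mode it avoids:

```python
import torch

latents = torch.randn(1, 4, 32, 32, dtype=torch.float16)  # sampling dtype
vae_param = torch.randn(4, 4, dtype=torch.float32)        # stand-in VAE weight

# Without the cast, mixing dtypes in the decode matmul raises a RuntimeError
# on most backends; casting first (as the diff now does) avoids it.
latents = latents.to(vae_param.dtype)
_ = torch.einsum("bchw,cd->bdhw", latents, vae_param)
```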
diffusers/pipelines/sana/__init__.py

```diff
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_sana"] = ["SanaPipeline"]
+    _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_sana import SanaPipeline
+        from .pipeline_sana_sprint import SanaSprintPipeline
 else:
     import sys
 
```
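With the lazy-import table updated, the new pipeline is importable from the package root. A hedged usage sketch (the checkpoint id is an assumption; SANA-Sprint targets few-step sampling, which is the point of the variant):

```python
import torch
from diffusers import SanaSprintPipeline

pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",  # assumed repo id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=2,  # few-step sampling is the variant's selling point
).images[0]
image.save("sana_sprint.png")
```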
diffusers/pipelines/sana/pipeline_sana.py

```diff
@@ -16,10 +16,11 @@ import html
 import inspect
 import re
 import urllib.parse as ul
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
@@ -31,6 +32,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -46,6 +48,13 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pipeline_output import SanaPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
```
```diff
@@ -55,6 +64,49 @@ if is_ftfy_available():
     import ftfy
 
 
+ASPECT_RATIO_4096_BIN = {
+    "0.25": [2048.0, 8192.0],
+    "0.26": [2048.0, 7936.0],
+    "0.27": [2048.0, 7680.0],
+    "0.28": [2048.0, 7424.0],
+    "0.32": [2304.0, 7168.0],
+    "0.33": [2304.0, 6912.0],
+    "0.35": [2304.0, 6656.0],
+    "0.4": [2560.0, 6400.0],
+    "0.42": [2560.0, 6144.0],
+    "0.48": [2816.0, 5888.0],
+    "0.5": [2816.0, 5632.0],
+    "0.52": [2816.0, 5376.0],
+    "0.57": [3072.0, 5376.0],
+    "0.6": [3072.0, 5120.0],
+    "0.68": [3328.0, 4864.0],
+    "0.72": [3328.0, 4608.0],
+    "0.78": [3584.0, 4608.0],
+    "0.82": [3584.0, 4352.0],
+    "0.88": [3840.0, 4352.0],
+    "0.94": [3840.0, 4096.0],
+    "1.0": [4096.0, 4096.0],
+    "1.07": [4096.0, 3840.0],
+    "1.13": [4352.0, 3840.0],
+    "1.21": [4352.0, 3584.0],
+    "1.29": [4608.0, 3584.0],
+    "1.38": [4608.0, 3328.0],
+    "1.46": [4864.0, 3328.0],
+    "1.67": [5120.0, 3072.0],
+    "1.75": [5376.0, 3072.0],
+    "2.0": [5632.0, 2816.0],
+    "2.09": [5888.0, 2816.0],
+    "2.4": [6144.0, 2560.0],
+    "2.5": [6400.0, 2560.0],
+    "2.89": [6656.0, 2304.0],
+    "3.0": [6912.0, 2304.0],
+    "3.11": [7168.0, 2304.0],
+    "3.62": [7424.0, 2048.0],
+    "3.75": [7680.0, 2048.0],
+    "3.88": [7936.0, 2048.0],
+    "4.0": [8192.0, 2048.0],
+}
+
 EXAMPLE_DOC_STRING = """
 Examples:
     ```py
```
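The new table follows the same scheme as the existing 1024/2048 bins: keys are height/width ratios, values are `[height, width]` pairs at the 4096 scale. Resolution binning snaps a request to the nearest bin, renders there, then crops back. A sketch, assuming the existing `PixArtImageProcessor.classify_height_width_bin` helper keeps its semantics:

```python
from diffusers.image_processor import PixArtImageProcessor

# Excerpt of the full table added in the diff above.
ASPECT_RATIO_4096_BIN = {"0.25": [2048.0, 8192.0], "1.0": [4096.0, 4096.0], "4.0": [8192.0, 2048.0]}

# A 1:1 request lands on the 4096x4096 bin.
height, width = PixArtImageProcessor.classify_height_width_bin(4000, 4000, ratios=ASPECT_RATIO_4096_BIN)
assert (height, width) == (4096, 4096)
```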
```diff
@@ -148,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
     def __init__(
         self,
-        tokenizer: AutoTokenizer,
-        text_encoder: AutoModelForCausalLM,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: DPMSolverMultistepScheduler,
```
```diff
@@ -167,6 +219,93 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def _get_gemma_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        dtype: torch.dtype,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`, *optional*):
+                torch device to place the resulting embeddings on
+            clean_caption (`bool`, defaults to `False`):
+                If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+                the prompt.
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"
+
+        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+        # prepare complex human instruction
+        if not complex_human_instruction:
+            max_length_all = max_sequence_length
+        else:
+            chi_prompt = "\n".join(complex_human_instruction)
+            prompt = [chi_prompt + p for p in prompt]
+            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+            max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_length_all,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        prompt_attention_mask = text_inputs.attention_mask
+        prompt_attention_mask = prompt_attention_mask.to(device)
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+        prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+        return prompt_embeds, prompt_attention_mask
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
```
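The four VAE helpers mirror the same methods on other diffusers pipelines, and `_get_gemma_prompt_embeds` centralizes the tokenize-and-encode path that `encode_prompt` previously inlined twice. A hedged usage sketch of the new memory helpers (the repo id is an assumption):

```python
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed repo id
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.enable_vae_slicing()  # decode batch elements one at a time
pipe.enable_vae_tiling()   # decode each sample in spatial tiles
images = pipe(prompt=["a red fox in snow", "a blue bird on a wire"]).images
```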
```diff
@@ -215,6 +354,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
+        if self.transformer is not None:
+            dtype = self.transformer.dtype
+        elif self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        else:
+            dtype = None
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
```
```diff
@@ -231,50 +377,26 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         else:
             batch_size = prompt_embeds.shape[0]
 
-        self.tokenizer.padding_side = "right"
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"
 
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
         select_index = [0] + list(range(-max_length + 1, 0))
 
         if prompt_embeds is None:
-            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
-
-            # prepare complex human instruction
-            if not complex_human_instruction:
-                max_length_all = max_length
-            else:
-                chi_prompt = "\n".join(complex_human_instruction)
-                prompt = [chi_prompt + p for p in prompt]
-                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-                max_length_all = num_chi_prompt_tokens + max_length - 2
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length_all,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=complex_human_instruction,
             )
-            text_input_ids = text_inputs.input_ids
 
-            prompt_attention_mask = text_inputs.attention_mask
-            prompt_attention_mask = prompt_attention_mask.to(device)
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0][:, select_index]
+            prompt_embeds = prompt_embeds[:, select_index]
             prompt_attention_mask = prompt_attention_mask[:, select_index]
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
-            dtype = self.text_encoder.dtype
-        else:
-            dtype = None
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
```
```diff
@@ -284,25 +406,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
-            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_attention_mask=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            negative_prompt_attention_mask = uncond_input.attention_mask
-            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+            negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=negative_prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=False,
             )
-            negative_prompt_embeds = negative_prompt_embeds[0]
 
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
```
```diff
@@ -611,7 +723,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clean_caption: bool = True,
+        clean_caption: bool = False,
         use_resolution_binning: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
```
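`clean_caption` now defaults to `False`, so the optional `bs4`/`ftfy` text-cleaning dependencies are only exercised on opt-in. What that preprocessing roughly does (assumed behavior, condensed from the imports the pipeline guards on):

```python
# Requires the optional deps the pipeline checks for: pip install bs4 ftfy
import ftfy
from bs4 import BeautifulSoup

raw = "A <b>bold</b> prompt â€” with mojibake"
cleaned = ftfy.fix_text(BeautifulSoup(raw, "html.parser").get_text())
print(cleaned)  # tags stripped, the broken dash repaired
```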
```diff
@@ -726,7 +838,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         # 1. Check inputs. Raise error if not correct
         if use_resolution_binning:
-            if self.transformer.config.sample_size == 64:
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+            elif self.transformer.config.sample_size == 64:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.transformer.config.sample_size == 32:
                 aspect_ratio_bin = ASPECT_RATIO_1024_BIN
```
```diff
@@ -824,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = timestep * self.transformer.config.timestep_scale
 
                 # predict noise model_output
                 noise_pred = self.transformer(
```
```diff
@@ -864,11 +979,21 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            try:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            except torch.cuda.OutOfMemoryError as e:
+                warnings.warn(
+                    f"{e}. \n"
+                    f"Try to use VAE tiling for large images. For example: \n"
+                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+                )
             if use_resolution_binning:
                 image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
```
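The new `try/except` wraps only the decode, which is where a 4K-scale latent most often exhausts memory, and the warning text points at VAE tiling as the fix. A condensed sketch of the same fallback, applied automatically rather than via a warning (tile sizes taken from the warning message; the helper itself is an illustration, not pipeline code):

```python
import torch


def decode_with_tiling_fallback(vae, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents, retrying with tiled decoding if CUDA memory runs out."""
    scaled = latents.to(vae.dtype) / vae.config.scaling_factor
    try:
        return vae.decode(scaled, return_dict=False)[0]
    except torch.cuda.OutOfMemoryError:
        # Tile sizes mirror the suggestion in the warning above.
        vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)
        return vae.decode(scaled, return_dict=False)[0]
```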