diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/ltx/__init__.py

@@ -22,9 +22,11 @@ except OptionalDependencyNotAvailable:
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
     _import_structure["pipeline_ltx"] = ["LTXPipeline"]
     _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
     _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
+    _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -34,9 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
+        from .modeling_latent_upsampler import LTXLatentUpsamplerModel
        from .pipeline_ltx import LTXPipeline
         from .pipeline_ltx_condition import LTXConditionPipeline
         from .pipeline_ltx_image2video import LTXImageToVideoPipeline
+        from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
 
 else:
     import sys
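
Note: the net effect of the two __init__.py hunks above is that the new latent upsampler model and pipeline are registered in the module's lazy-import table and re-exported from the subpackage. A minimal sanity check, assuming diffusers 0.34.0 with torch and transformers installed:

import diffusers

# Both names resolve through the lazy-import machinery added above.
from diffusers.pipelines.ltx import LTXLatentUpsamplePipeline, LTXLatentUpsamplerModel

print(diffusers.__version__)  # expected: 0.34.0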
diffusers/pipelines/ltx/modeling_latent_upsampler.py (new file)

@@ -0,0 +1,188 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
+        super().__init__()
+        if mid_channels is None:
+            mid_channels = channels
+
+        Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+        self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
+        self.norm1 = torch.nn.GroupNorm(32, mid_channels)
+        self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
+        self.norm2 = torch.nn.GroupNorm(32, channels)
+        self.activation = torch.nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.activation(hidden_states + residual)
+        return hidden_states
+
+
+class PixelShuffleND(torch.nn.Module):
+    def __init__(self, dims, upscale_factors=(2, 2, 2)):
+        super().__init__()
+
+        self.dims = dims
+        self.upscale_factors = upscale_factors
+
+        if dims not in [1, 2, 3]:
+            raise ValueError("dims must be 1, 2, or 3")
+
+    def forward(self, x):
+        if self.dims == 3:
+            # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
+            return (
+                x.unflatten(1, (-1, *self.upscale_factors[:3]))
+                .permute(0, 1, 5, 2, 6, 3, 7, 4)
+                .flatten(6, 7)
+                .flatten(4, 5)
+                .flatten(2, 3)
+            )
+        elif self.dims == 2:
+            # spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
+            return (
+                x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
+            )
+        elif self.dims == 1:
+            # temporal: b (c p1) f h w -> b c (f p1) h w
+            return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
+
+
+class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
+    """
+    Model to spatially upsample VAE latents.
+
+    Args:
+        in_channels (`int`, defaults to `128`):
+            Number of channels in the input latent
+        mid_channels (`int`, defaults to `512`):
+            Number of channels in the middle layers
+        num_blocks_per_stage (`int`, defaults to `4`):
+            Number of ResBlocks to use in each stage (pre/post upsampling)
+        dims (`int`, defaults to `3`):
+            Number of dimensions for convolutions (2 or 3)
+        spatial_upsample (`bool`, defaults to `True`):
+            Whether to spatially upsample the latent
+        temporal_upsample (`bool`, defaults to `False`):
+            Whether to temporally upsample the latent
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 128,
+        mid_channels: int = 512,
+        num_blocks_per_stage: int = 4,
+        dims: int = 3,
+        spatial_upsample: bool = True,
+        temporal_upsample: bool = False,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.mid_channels = mid_channels
+        self.num_blocks_per_stage = num_blocks_per_stage
+        self.dims = dims
+        self.spatial_upsample = spatial_upsample
+        self.temporal_upsample = temporal_upsample
+
+        ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+        self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
+        self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
+        self.initial_activation = torch.nn.SiLU()
+
+        self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
+
+        if spatial_upsample and temporal_upsample:
+            self.upsampler = torch.nn.Sequential(
+                torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
+                PixelShuffleND(3),
+            )
+        elif spatial_upsample:
+            self.upsampler = torch.nn.Sequential(
+                torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
+                PixelShuffleND(2),
+            )
+        elif temporal_upsample:
+            self.upsampler = torch.nn.Sequential(
+                torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
+                PixelShuffleND(1),
+            )
+        else:
+            raise ValueError("Either spatial_upsample or temporal_upsample must be True")
+
+        self.post_upsample_res_blocks = torch.nn.ModuleList(
+            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
+        )
+
+        self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+        if self.dims == 2:
+            hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+            hidden_states = self.initial_conv(hidden_states)
+            hidden_states = self.initial_norm(hidden_states)
+            hidden_states = self.initial_activation(hidden_states)
+
+            for block in self.res_blocks:
+                hidden_states = block(hidden_states)
+
+            hidden_states = self.upsampler(hidden_states)
+
+            for block in self.post_upsample_res_blocks:
+                hidden_states = block(hidden_states)
+
+            hidden_states = self.final_conv(hidden_states)
+            hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+        else:
+            hidden_states = self.initial_conv(hidden_states)
+            hidden_states = self.initial_norm(hidden_states)
+            hidden_states = self.initial_activation(hidden_states)
+
+            for block in self.res_blocks:
+                hidden_states = block(hidden_states)
+
+            if self.temporal_upsample:
+                hidden_states = self.upsampler(hidden_states)
+                hidden_states = hidden_states[:, :, 1:, :, :]
+            else:
+                hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+                hidden_states = self.upsampler(hidden_states)
+                hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+            for block in self.post_upsample_res_blocks:
+                hidden_states = block(hidden_states)
+
+            hidden_states = self.final_conv(hidden_states)
+
+        return hidden_states
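
Note: since the file above is new in 0.34.0, a shape-level smoke test makes its behavior concrete. This is a hedged sketch (the input sizes are arbitrary, not from the diff): with the default config (dims=3, spatial_upsample=True, temporal_upsample=False), forward folds the frame axis into the batch, runs the Conv2d + PixelShuffleND(2) stage, and doubles height and width while leaving the channel and frame counts unchanged.

import torch

from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel

# Default config: 128 latent channels, 3D res blocks, 2x spatial upsampling.
upsampler = LTXLatentUpsamplerModel()
latents = torch.randn(1, 128, 5, 16, 16)  # (batch, channels, frames, height, width)
with torch.no_grad():
    upscaled = upsampler(latents)
print(upscaled.shape)  # torch.Size([1, 128, 5, 32, 32]) -- height and width doubled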
diffusers/pipelines/ltx/pipeline_ltx.py

@@ -1,4 +1,4 @@
-# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -140,6 +140,33 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://huggingface.co/papers/2305.08891).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
 class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
     Pipeline for text-to-video generation.
@@ -481,6 +508,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
     @property
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1.0
@@ -514,6 +545,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
         guidance_scale: float = 3,
+        guidance_rescale: float = 0.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -551,11 +583,16 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
             guidance_scale (`float`, defaults to `3 `):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of videos to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -624,6 +661,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
         )
 
         self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
         self._current_timestep = None
@@ -737,6 +775,12 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                if self.guidance_rescale > 0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+                    )
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
@@ -789,6 +833,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                 ]
                 latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
 
+            latents = latents.to(self.vae.dtype)
             video = self.vae.decode(latents, timestep, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
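Note: with these hunks, LTXPipeline gains the same `guidance_rescale` knob the Stable Diffusion pipelines expose. After the CFG combination, `rescale_noise_cfg` shrinks the CFG-inflated standard deviation of `noise_pred` back toward that of `noise_pred_text`, blended by the factor φ = `guidance_rescale`. A hedged usage sketch follows; the checkpoint id and argument values are illustrative assumptions, not taken from this diff:

import torch

from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
video = pipe(
    prompt="A coral reef teeming with fish, sunbeams filtering through the water",
    num_inference_steps=50,
    guidance_scale=3.0,
    guidance_rescale=0.7,  # new in 0.34.0; the default 0.0 disables rescaling
).frames[0]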
diffusers/pipelines/ltx/pipeline_ltx_condition.py

@@ -1,4 +1,4 @@
-# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -222,6 +222,33 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://huggingface.co/papers/2305.08891).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
 class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
     Pipeline for text/image/video-to-video generation.
@@ -430,6 +457,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         video,
         frame_index,
         strength,
+        denoise_strength,
         height,
         width,
         callback_on_step_end_tensor_inputs=None,
@@ -497,6 +525,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
             raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
 
+        if denoise_strength < 0 or denoise_strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
+
     @staticmethod
     def _prepare_video_ids(
         batch_size: int,
@@ -649,6 +680,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         width: int = 704,
         num_frames: int = 161,
         num_prefix_latent_frames: int = 2,
+        sigma: Optional[torch.Tensor] = None,
+        latents: Optional[torch.Tensor] = None,
         generator: Optional[torch.Generator] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
@@ -658,7 +691,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         latent_width = width // self.vae_spatial_compression_ratio
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if latents is not None and sigma is not None:
+            if latents.shape != shape:
+                raise ValueError(
+                    f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
+                )
+            latents = latents.to(device=device, dtype=dtype)
+            sigma = sigma.to(device=device, dtype=dtype)
+            latents = sigma * noise + (1 - sigma) * latents
+        else:
+            latents = noise
 
         if len(conditions) > 0:
             condition_latent_frames_mask = torch.zeros(
@@ -766,10 +810,21 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
 
         return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
 
+    def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
+        num_steps = min(int(num_inference_steps * strength), num_inference_steps)
+        start_index = max(num_inference_steps - num_steps, 0)
+        sigmas = sigmas[start_index:]
+        timesteps = timesteps[start_index:]
+        return sigmas, timesteps, num_inference_steps - start_index
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
     @property
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1.0
@@ -799,6 +854,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         video: List[PipelineImageInput] = None,
         frame_index: Union[int, List[int]] = 0,
         strength: Union[float, List[float]] = 1.0,
+        denoise_strength: float = 1.0,
         prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
@@ -808,6 +864,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
         guidance_scale: float = 3,
+        guidance_rescale: float = 0.0,
         image_cond_noise_scale: float = 0.15,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -842,6 +899,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 generation. If not provided, one has to pass `conditions`.
             strength (`float` or `List[float]`, *optional*):
                 The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
+            denoise_strength (`float`, defaults to `1.0`):
+                The strength of the noise added to the latents for editing. Higher strength leads to more noise added
+                to the latents, therefore leading to more differences between original video and generated video. This
+                is useful for video-to-video editing.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
@@ -859,11 +920,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
             guidance_scale (`float`, defaults to `3 `):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of videos to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -918,8 +984,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
 
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-        if latents is not None:
-            raise ValueError("Passing latents is not yet supported.")
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -929,6 +993,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             video=video,
             frame_index=frame_index,
             strength=strength,
+            denoise_strength=denoise_strength,
             height=height,
             width=width,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -939,6 +1004,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         )
 
         self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
         self._current_timestep = None
@@ -977,8 +1043,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             strength = [strength] * num_conditions
 
         device = self._execution_device
+        vae_dtype = self.vae.dtype
 
-        # 3. Prepare text embeddings
+        # 3. Prepare text embeddings & conditioning image/video
         (
             prompt_embeds,
             prompt_attention_mask,
@@ -1000,8 +1067,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
-        vae_dtype = self.vae.dtype
-
         conditioning_tensors = []
         is_conditioning_image_or_video = image is not None or video is not None
         if is_conditioning_image_or_video:
@@ -1032,7 +1097,27 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             )
             conditioning_tensors.append(condition_tensor)
 
-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+        latent_height = height // self.vae_spatial_compression_ratio
+        latent_width = width // self.vae_spatial_compression_ratio
+        if timesteps is None:
+            sigmas = linear_quadratic_schedule(num_inference_steps)
+            timesteps = sigmas * 1000
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        sigmas = self.scheduler.sigmas
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        latent_sigma = None
+        if denoise_strength < 1:
+            sigmas, timesteps, num_inference_steps = self.get_timesteps(
+                sigmas, timesteps, num_inference_steps, denoise_strength
+            )
+            latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
+
+        self._num_timesteps = len(timesteps)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
             conditioning_tensors,
@@ -1043,6 +1128,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             height=height,
             width=width,
             num_frames=num_frames,
+            sigma=latent_sigma,
+            latents=latents,
             generator=generator,
             device=device,
             dtype=torch.float32,
@@ -1056,21 +1143,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
 
-        # 5. Prepare timesteps
-        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-        latent_height = height // self.vae_spatial_compression_ratio
-        latent_width = width // self.vae_spatial_compression_ratio
-        sigmas = linear_quadratic_schedule(num_inference_steps)
-        timesteps = sigmas * 1000
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps=timesteps,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1120,6 +1192,12 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                     timestep, _ = timestep.chunk(2)
 
+                if self.guidance_rescale > 0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+                    )
+
                 denoised_latents = self.scheduler.step(
                     -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
                 )[0]
@@ -1168,7 +1246,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             if not self.vae.config.timestep_conditioning:
                 timestep = None
             else:
-                noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+                noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
                 if not isinstance(decode_timestep, list):
                     decode_timestep = [decode_timestep] * batch_size
                 if decode_noise_scale is None:
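
Note: putting the pipeline_ltx_condition.py changes together: `latents` may now be passed in (the old "Passing latents is not yet supported" guard is removed), and when `denoise_strength < 1` the new `get_timesteps` keeps only the last `int(num_inference_steps * denoise_strength)` steps of the schedule, while `prepare_latents` blends the provided latents with fresh noise as `sigma * noise + (1 - sigma) * latents` at the first remaining sigma. Combined with the new `LTXLatentUpsamplePipeline`, this enables a generate / latent-upsample / refine flow. A hedged sketch follows; the checkpoint ids, sizes, and step counts are assumptions, not taken from this diff:

import torch

from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A coral reef teeming with fish"

# Stage 1: generate at low resolution and keep the latents.
latents = pipe(
    prompt=prompt, width=512, height=352, num_frames=97, num_inference_steps=30, output_type="latent"
).frames

# Stage 2: 2x spatial upsample in latent space, then a short partial denoise over
# the upsampled latents (denoise_strength=0.4 runs only the last 40% of the schedule).
latents = pipe_upsample(latents=latents, output_type="latent").frames
video = pipe(
    prompt=prompt,
    width=1024,
    height=704,
    num_frames=97,
    latents=latents,
    denoise_strength=0.4,
    num_inference_steps=10,
).frames[0]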