diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/utils/remote_utils.py (new file)
@@ -0,0 +1,425 @@
+ # coding=utf-8
+ # Copyright 2025 HuggingFace Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import io
+ import json
+ from typing import List, Literal, Optional, Union, cast
+
+ import requests
+
+ from .deprecation_utils import deprecate
+ from .import_utils import is_safetensors_available, is_torch_available
+
+
+ if is_torch_available():
+     import torch
+
+     from ..image_processor import VaeImageProcessor
+     from ..video_processor import VideoProcessor
+
+ if is_safetensors_available():
+     import safetensors.torch
+
+     DTYPE_MAP = {
+         "float16": torch.float16,
+         "float32": torch.float32,
+         "bfloat16": torch.bfloat16,
+         "uint8": torch.uint8,
+     }
+
+
+ from PIL import Image
+
+
+ def detect_image_type(data: bytes) -> str:
+     if data.startswith(b"\xff\xd8"):
+         return "jpeg"
+     elif data.startswith(b"\x89PNG\r\n\x1a\n"):
+         return "png"
+     elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
+         return "gif"
+     elif data.startswith(b"BM"):
+         return "bmp"
+     return "unknown"
+
+
+ def check_inputs_decode(
+     endpoint: str,
+     tensor: "torch.Tensor",
+     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+     do_scaling: bool = True,
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
+     output_type: Literal["mp4", "pil", "pt"] = "pil",
+     return_type: Literal["mp4", "pil", "pt"] = "pil",
+     image_format: Literal["png", "jpg"] = "jpg",
+     partial_postprocess: bool = False,
+     input_tensor_type: Literal["binary"] = "binary",
+     output_tensor_type: Literal["binary"] = "binary",
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+ ):
+     if tensor.ndim == 3 and height is None and width is None:
+         raise ValueError("`height` and `width` required for packed latents.")
+     if (
+         output_type == "pt"
+         and return_type == "pil"
+         and not partial_postprocess
+         and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
+     ):
+         raise ValueError("`processor` is required.")
+     if do_scaling and scaling_factor is None:
+         deprecate(
+             "do_scaling",
+             "1.0.0",
+             "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
+             standard_warn=False,
+         )
+
+
+ def postprocess_decode(
+     response: requests.Response,
+     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+     output_type: Literal["mp4", "pil", "pt"] = "pil",
+     return_type: Literal["mp4", "pil", "pt"] = "pil",
+     partial_postprocess: bool = False,
+ ):
+     if output_type == "pt" or (output_type == "pil" and processor is not None):
+         output_tensor = response.content
+         parameters = response.headers
+         shape = json.loads(parameters["shape"])
+         dtype = parameters["dtype"]
+         torch_dtype = DTYPE_MAP[dtype]
+         output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+     if output_type == "pt":
+         if partial_postprocess:
+             if return_type == "pil":
+                 output = [Image.fromarray(image.numpy()) for image in output_tensor]
+                 if len(output) == 1:
+                     output = output[0]
+             elif return_type == "pt":
+                 output = output_tensor
+         else:
+             if processor is None or return_type == "pt":
+                 output = output_tensor
+             else:
+                 if isinstance(processor, VideoProcessor):
+                     output = cast(
+                         List[Image.Image],
+                         processor.postprocess_video(output_tensor, output_type="pil")[0],
+                     )
+                 else:
+                     output = cast(
+                         Image.Image,
+                         processor.postprocess(output_tensor, output_type="pil")[0],
+                     )
+     elif output_type == "pil" and return_type == "pil" and processor is None:
+         output = Image.open(io.BytesIO(response.content)).convert("RGB")
+         detected_format = detect_image_type(response.content)
+         output.format = detected_format
+     elif output_type == "pil" and processor is not None:
+         if return_type == "pil":
+             output = [
+                 Image.fromarray(image)
+                 for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
+             ]
+         elif return_type == "pt":
+             output = output_tensor
+     elif output_type == "mp4" and return_type == "mp4":
+         output = response.content
+     return output
+
+
+ def prepare_decode(
+     tensor: "torch.Tensor",
+     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+     do_scaling: bool = True,
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
+     output_type: Literal["mp4", "pil", "pt"] = "pil",
+     image_format: Literal["png", "jpg"] = "jpg",
+     partial_postprocess: bool = False,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+ ):
+     headers = {}
+     parameters = {
+         "image_format": image_format,
+         "output_type": output_type,
+         "partial_postprocess": partial_postprocess,
+         "shape": list(tensor.shape),
+         "dtype": str(tensor.dtype).split(".")[-1],
+     }
+     if do_scaling and scaling_factor is not None:
+         parameters["scaling_factor"] = scaling_factor
+     if do_scaling and shift_factor is not None:
+         parameters["shift_factor"] = shift_factor
+     if do_scaling and scaling_factor is None:
+         parameters["do_scaling"] = do_scaling
+     elif do_scaling and scaling_factor is None and shift_factor is None:
+         parameters["do_scaling"] = do_scaling
+     if height is not None and width is not None:
+         parameters["height"] = height
+         parameters["width"] = width
+     headers["Content-Type"] = "tensor/binary"
+     headers["Accept"] = "tensor/binary"
+     if output_type == "pil" and image_format == "jpg" and processor is None:
+         headers["Accept"] = "image/jpeg"
+     elif output_type == "pil" and image_format == "png" and processor is None:
+         headers["Accept"] = "image/png"
+     elif output_type == "mp4":
+         headers["Accept"] = "text/plain"
+     tensor_data = safetensors.torch._tobytes(tensor, "tensor")
+     return {"data": tensor_data, "params": parameters, "headers": headers}
+
+
+ def remote_decode(
+     endpoint: str,
+     tensor: "torch.Tensor",
+     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+     do_scaling: bool = True,
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
+     output_type: Literal["mp4", "pil", "pt"] = "pil",
+     return_type: Literal["mp4", "pil", "pt"] = "pil",
+     image_format: Literal["png", "jpg"] = "jpg",
+     partial_postprocess: bool = False,
+     input_tensor_type: Literal["binary"] = "binary",
+     output_tensor_type: Literal["binary"] = "binary",
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+ ) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
+     """
+     Hugging Face Hybrid Inference that allows running VAE decode remotely.
+
+     Args:
+         endpoint (`str`):
+             Endpoint for Remote Decode.
+         tensor (`torch.Tensor`):
+             Tensor to be decoded.
+         processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
+             Used with `return_type="pt"`, and with `return_type="pil"` for video models.
+         do_scaling (`bool`, default `True`, *optional*):
+             **DEPRECATED**: pass `scaling_factor`/`shift_factor` instead (until the option is removed, set
+             `do_scaling=None`/`do_scaling=False` for no scaling). When `True`, scaling (e.g. `latents /
+             self.vae.config.scaling_factor`) is applied remotely. If `False`, the input must be passed with
+             scaling applied.
+         scaling_factor (`float`, *optional*):
+             Scaling is applied when passed, e.g. [`latents /
+             self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77).
+             - SD v1: 0.18215
+             - SD XL: 0.13025
+             - Flux: 0.3611
+             If `None`, the input must be passed with scaling applied.
+         shift_factor (`float`, *optional*):
+             Shift is applied when passed, e.g. `latents + self.vae.config.shift_factor`.
+             - Flux: 0.1159
+             If `None`, the input must be passed with shift applied.
+         output_type (`"mp4"` or `"pil"` or `"pt"`, default `"pil"`):
+             **Endpoint** output type. Subject to change. Report feedback on preferred type.
+
+             `"mp4"`: Supported by video models. Endpoint returns `bytes` of video. `"pil"`: Supported by image
+             and video models. Image models: endpoint returns `bytes` of an image in `image_format`. Video
+             models: endpoint returns a `torch.Tensor` with partial postprocessing applied; requires `processor`
+             as a flag (any non-`None` value will work). `"pt"`: Supported by image and video models. Endpoint
+             returns a `torch.Tensor`. With `partial_postprocess=True` the tensor is a postprocessed `uint8`
+             image tensor.
+
+             Recommendations:
+                 `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. `"pt"` with
+                 `partial_postprocess=False` is the most compatible with third-party code. `"pil"` with
+                 `image_format="jpg"` is the smallest transfer overall.
+
+         return_type (`"mp4"` or `"pil"` or `"pt"`, default `"pil"`):
+             **Function** return type.
+
+             `"mp4"`: Function returns `bytes` of video. `"pil"`: Function returns `PIL.Image.Image`. With
+             `output_type="pil"` no further processing is applied; with `output_type="pt"` a `PIL.Image.Image`
+             is created, which requires `processor` when `partial_postprocess=False` but **not** when
+             `partial_postprocess=True`. `"pt"`: Function returns `torch.Tensor`; `processor` is **not**
+             required. With `partial_postprocess=False` the tensor is `float16` or `bfloat16`, without
+             denormalization; with `partial_postprocess=True` the tensor is `uint8`, denormalized.
+
+         image_format (`"png"` or `"jpg"`, default `"jpg"`):
+             Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.
+
+         partial_postprocess (`bool`, default `False`):
+             Used with `output_type="pt"`. If `False`, the tensor is `float16` or `bfloat16`, without
+             denormalization; if `True`, the tensor is `uint8`, denormalized.
+
+         input_tensor_type (`"binary"`, default `"binary"`):
+             Tensor transfer type.
+
+         output_tensor_type (`"binary"`, default `"binary"`):
+             Tensor transfer type.
+
+         height (`int`, *optional*):
+             Required for `"packed"` latents.
+
+         width (`int`, *optional*):
+             Required for `"packed"` latents.
+
+     Returns:
+         output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`).
+     """
+     if input_tensor_type == "base64":
+         deprecate(
+             "input_tensor_type='base64'",
+             "1.0.0",
+             "input_tensor_type='base64' is deprecated. Using `binary`.",
+             standard_warn=False,
+         )
+         input_tensor_type = "binary"
+     if output_tensor_type == "base64":
+         deprecate(
+             "output_tensor_type='base64'",
+             "1.0.0",
+             "output_tensor_type='base64' is deprecated. Using `binary`.",
+             standard_warn=False,
+         )
+         output_tensor_type = "binary"
+     check_inputs_decode(
+         endpoint,
+         tensor,
+         processor,
+         do_scaling,
+         scaling_factor,
+         shift_factor,
+         output_type,
+         return_type,
+         image_format,
+         partial_postprocess,
+         input_tensor_type,
+         output_tensor_type,
+         height,
+         width,
+     )
+     kwargs = prepare_decode(
+         tensor=tensor,
+         processor=processor,
+         do_scaling=do_scaling,
+         scaling_factor=scaling_factor,
+         shift_factor=shift_factor,
+         output_type=output_type,
+         image_format=image_format,
+         partial_postprocess=partial_postprocess,
+         height=height,
+         width=width,
+     )
+     response = requests.post(endpoint, **kwargs)
+     if not response.ok:
+         raise RuntimeError(response.json())
+     output = postprocess_decode(
+         response=response,
+         processor=processor,
+         output_type=output_type,
+         return_type=return_type,
+         partial_postprocess=partial_postprocess,
+     )
+     return output
+
+
+ def check_inputs_encode(
+     endpoint: str,
+     image: Union["torch.Tensor", Image.Image],
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
+ ):
+     pass
+
+
+ def postprocess_encode(
+     response: requests.Response,
+ ):
+     output_tensor = response.content
+     parameters = response.headers
+     shape = json.loads(parameters["shape"])
+     dtype = parameters["dtype"]
+     torch_dtype = DTYPE_MAP[dtype]
+     output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+     return output_tensor
+
+
+ def prepare_encode(
+     image: Union["torch.Tensor", Image.Image],
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
+ ):
+     headers = {}
+     parameters = {}
+     if scaling_factor is not None:
+         parameters["scaling_factor"] = scaling_factor
+     if shift_factor is not None:
+         parameters["shift_factor"] = shift_factor
+     if isinstance(image, torch.Tensor):
+         data = safetensors.torch._tobytes(image.contiguous(), "tensor")
+         parameters["shape"] = list(image.shape)
+         parameters["dtype"] = str(image.dtype).split(".")[-1]
+     else:
+         buffer = io.BytesIO()
+         image.save(buffer, format="PNG")
+         data = buffer.getvalue()
+     return {"data": data, "params": parameters, "headers": headers}
+
+
+ def remote_encode(
+     endpoint: str,
+     image: Union["torch.Tensor", Image.Image],
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
+ ) -> "torch.Tensor":
+     """
+     Hugging Face Hybrid Inference that allows running VAE encode remotely.
+
+     Args:
+         endpoint (`str`):
+             Endpoint for Remote Encode.
+         image (`torch.Tensor` or `PIL.Image.Image`):
+             Image to be encoded.
+         scaling_factor (`float`, *optional*):
+             Scaling is applied when passed, e.g. `latents * self.vae.config.scaling_factor`.
+             - SD v1: 0.18215
+             - SD XL: 0.13025
+             - Flux: 0.3611
+             If `None`, scaling is not applied remotely.
+         shift_factor (`float`, *optional*):
+             Shift is applied when passed, e.g. `latents - self.vae.config.shift_factor`.
+             - Flux: 0.1159
+             If `None`, shift is not applied remotely.
+
+     Returns:
+         output (`torch.Tensor`).
+     """
+     check_inputs_encode(
+         endpoint,
+         image,
+         scaling_factor,
+         shift_factor,
+     )
+     kwargs = prepare_encode(
+         image=image,
+         scaling_factor=scaling_factor,
+         shift_factor=shift_factor,
+     )
+     response = requests.post(endpoint, **kwargs)
+     if not response.ok:
+         raise RuntimeError(response.json())
+     output = postprocess_encode(
+         response=response,
+     )
+     return output
diffusers/utils/source_code_parsing_utils.py (new file)
@@ -0,0 +1,52 @@
+ import ast
+ import importlib
+ import inspect
+ import textwrap
+
+
+ class ReturnNameVisitor(ast.NodeVisitor):
+     """Thanks to ChatGPT for pairing."""
+
+     def __init__(self):
+         self.return_names = []
+
+     def visit_Return(self, node):
+         # Check if the return value is a tuple.
+         if isinstance(node.value, ast.Tuple):
+             for elt in node.value.elts:
+                 if isinstance(elt, ast.Name):
+                     self.return_names.append(elt.id)
+                 else:
+                     try:
+                         self.return_names.append(ast.unparse(elt))
+                     except Exception:
+                         self.return_names.append(str(elt))
+         else:
+             if isinstance(node.value, ast.Name):
+                 self.return_names.append(node.value.id)
+             else:
+                 try:
+                     self.return_names.append(ast.unparse(node.value))
+                 except Exception:
+                     self.return_names.append(str(node.value))
+         self.generic_visit(node)
+
+     def _determine_parent_module(self, cls):
+         from diffusers import DiffusionPipeline
+         from diffusers.models.modeling_utils import ModelMixin
+
+         if issubclass(cls, DiffusionPipeline):
+             return "pipelines"
+         elif issubclass(cls, ModelMixin):
+             return "models"
+         else:
+             raise NotImplementedError
+
+     def get_ast_tree(self, cls, attribute_name="encode_prompt"):
+         parent_module_name = self._determine_parent_module(cls)
+         main_module = importlib.import_module(f"diffusers.{parent_module_name}")
+         current_cls_module = getattr(main_module, cls.__name__)
+         source_code = inspect.getsource(getattr(current_cls_module, attribute_name))
+         source_code = textwrap.dedent(source_code)
+         tree = ast.parse(source_code)
+         return tree
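`ReturnNameVisitor` walks a method's AST and records the names appearing in its `return` statements. A small sketch of driving it by hand against a standard pipeline class; the printed names are indicative, not guaranteed:

from diffusers import StableDiffusionPipeline
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor

visitor = ReturnNameVisitor()
# Parse the source of StableDiffusionPipeline.encode_prompt ...
tree = visitor.get_ast_tree(StableDiffusionPipeline, attribute_name="encode_prompt")
# ... and collect the names it returns.
visitor.visit(tree)
print(visitor.return_names)  # expected to include prompt_embeds and negative_prompt_embeds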
diffusers/utils/state_dict_utils.py
@@ -17,9 +17,14 @@ State dict utilities: utility methods for converting state dicts easily
 
  import enum
 
+ from .import_utils import is_torch_available
  from .logging import get_logger
 
 
+ if is_torch_available():
+     import torch
+
+
  logger = get_logger(__name__)
 
 
@@ -329,7 +334,16 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
          kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
          kohya_ss_state_dict[kohya_key] = weight
          if "lora_down" in kohya_key:
-             alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+             alpha_key = f"{kohya_key.split('.')[0]}.alpha"
              kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
 
      return kohya_ss_state_dict
+
+
+ def state_dict_all_zero(state_dict, filter_str=None):
+     if filter_str is not None:
+         if isinstance(filter_str, str):
+             filter_str = [filter_str]
+         state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
+
+     return all(torch.all(param == 0).item() for param in state_dict.values())
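The new `state_dict_all_zero` helper reports whether every tensor in a state dict, optionally restricted to keys containing one of the given substrings, is exactly zero. A quick self-contained check:

import torch
from diffusers.utils.state_dict_utils import state_dict_all_zero

state_dict = {
    "lora_A.weight": torch.zeros(4, 4),
    "lora_B.weight": torch.ones(4, 4),
}

print(state_dict_all_zero(state_dict))                       # False: lora_B.weight is non-zero
print(state_dict_all_zero(state_dict, filter_str="lora_A"))  # True: only zero tensors match the filter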