diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445) hide show
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -191,7 +191,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
191
191
 
192
192
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
193
193
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
194
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
194
+ shape = (
195
+ batch_size,
196
+ num_channels_latents,
197
+ int(height) // self.vae_scale_factor,
198
+ int(width) // self.vae_scale_factor,
199
+ )
195
200
  if isinstance(generator, list) and len(generator) != batch_size:
196
201
  raise ValueError(
197
202
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -219,10 +224,10 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
219
224
  num_images_per_prompt: int = 1,
220
225
  eta: float = 0.0,
221
226
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
222
- latents: Optional[torch.FloatTensor] = None,
227
+ latents: Optional[torch.Tensor] = None,
223
228
  output_type: Optional[str] = "pil",
224
229
  return_dict: bool = True,
225
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
230
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
226
231
  callback_steps: int = 1,
227
232
  editing_prompt: Optional[Union[str, List[str]]] = None,
228
233
  editing_prompt_embeddings: Optional[torch.Tensor] = None,
@@ -263,7 +268,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
263
268
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
264
269
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
265
270
  generation deterministic.
266
- latents (`torch.FloatTensor`, *optional*):
271
+ latents (`torch.Tensor`, *optional*):
267
272
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
268
273
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
269
274
  tensor is generated by sampling using the supplied random `generator`.
@@ -274,7 +279,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
274
279
  plain tuple.
275
280
  callback (`Callable`, *optional*):
276
281
  A function that is called every `callback_steps` steps during inference. The function is called with the
277
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
282
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
278
283
  callback_steps (`int`, *optional*, defaults to 1):
279
284
  The frequency at which the `callback` function is called. If not specified, the callback is called at
280
285
  every step.
@@ -371,6 +376,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
371
376
 
372
377
  # 2. Define call parameters
373
378
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
379
+ device = self._execution_device
374
380
 
375
381
  if editing_prompt:
376
382
  enable_edit_guidance = True
@@ -400,7 +406,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
400
406
  f" {self.tokenizer.model_max_length} tokens: {removed_text}"
401
407
  )
402
408
  text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
403
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
409
+ text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
404
410
 
405
411
  # duplicate text embeddings for each generation per prompt, using mps friendly method
406
412
  bs_embed, seq_len, _ = text_embeddings.shape
@@ -428,9 +434,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
428
434
  f" {self.tokenizer.model_max_length} tokens: {removed_text}"
429
435
  )
430
436
  edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
431
- edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
437
+ edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
432
438
  else:
433
- edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
439
+ edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)
434
440
 
435
441
  # duplicate text embeddings for each generation per prompt, using mps friendly method
436
442
  bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
@@ -471,7 +477,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
471
477
  truncation=True,
472
478
  return_tensors="pt",
473
479
  )
474
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
480
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
475
481
 
476
482
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
477
483
  seq_len = uncond_embeddings.shape[1]
@@ -488,7 +494,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
488
494
  # get the initial random noise unless the user supplied it
489
495
 
490
496
  # 4. Prepare timesteps
491
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
497
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
492
498
  timesteps = self.scheduler.timesteps
493
499
 
494
500
  # 5. Prepare latent variables
@@ -499,7 +505,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
499
505
  height,
500
506
  width,
501
507
  text_embeddings.dtype,
502
- self.device,
508
+ device,
503
509
  generator,
504
510
  latents,
505
511
  )
@@ -557,12 +563,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
557
563
  if enable_edit_guidance:
558
564
  concept_weights = torch.zeros(
559
565
  (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
560
- device=self.device,
566
+ device=device,
561
567
  dtype=noise_guidance.dtype,
562
568
  )
563
569
  noise_guidance_edit = torch.zeros(
564
570
  (len(noise_pred_edit_concepts), *noise_guidance.shape),
565
- device=self.device,
571
+ device=device,
566
572
  dtype=noise_guidance.dtype,
567
573
  )
568
574
  # noise_guidance_edit = torch.zeros_like(noise_guidance)
@@ -639,33 +645,30 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
639
645
 
640
646
  # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
641
647
 
642
- warmup_inds = torch.tensor(warmup_inds).to(self.device)
648
+ warmup_inds = torch.tensor(warmup_inds).to(device)
643
649
  if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
644
650
  concept_weights = concept_weights.to("cpu") # Offload to cpu
645
651
  noise_guidance_edit = noise_guidance_edit.to("cpu")
646
652
 
647
- concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
653
+ concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
648
654
  concept_weights_tmp = torch.where(
649
655
  concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
650
656
  )
651
657
  concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
652
658
  # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
653
659
 
654
- noise_guidance_edit_tmp = torch.index_select(
655
- noise_guidance_edit.to(self.device), 0, warmup_inds
656
- )
660
+ noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
657
661
  noise_guidance_edit_tmp = torch.einsum(
658
662
  "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
659
663
  )
660
- noise_guidance_edit_tmp = noise_guidance_edit_tmp
661
664
  noise_guidance = noise_guidance + noise_guidance_edit_tmp
662
665
 
663
666
  self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
664
667
 
665
668
  del noise_guidance_edit_tmp
666
669
  del concept_weights_tmp
667
- concept_weights = concept_weights.to(self.device)
668
- noise_guidance_edit = noise_guidance_edit.to(self.device)
670
+ concept_weights = concept_weights.to(device)
671
+ noise_guidance_edit = noise_guidance_edit.to(device)
669
672
 
670
673
  concept_weights = torch.where(
671
674
  concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
@@ -674,6 +677,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
674
677
  concept_weights = torch.nan_to_num(concept_weights)
675
678
 
676
679
  noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
680
+ noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)
677
681
 
678
682
  noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
679
683
 
@@ -684,7 +688,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
684
688
  self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
685
689
 
686
690
  if sem_guidance is not None:
687
- edit_guidance = sem_guidance[i].to(self.device)
691
+ edit_guidance = sem_guidance[i].to(device)
688
692
  noise_guidance = noise_guidance + edit_guidance
689
693
 
690
694
  noise_pred = noise_pred_uncond + noise_guidance
@@ -700,7 +704,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
700
704
  # 8. Post-processing
701
705
  if not output_type == "latent":
702
706
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
703
- image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
707
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
704
708
  else:
705
709
  image = latents
706
710
  has_nsfw_concept = None
@@ -69,7 +69,7 @@ class ShapEPipelineOutput(BaseOutput):
69
69
  Output class for [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`].
70
70
 
71
71
  Args:
72
- images (`torch.FloatTensor`)
72
+ images (`torch.Tensor`)
73
73
  A list of images for 3D rendering.
74
74
  """
75
75
 
@@ -187,7 +187,7 @@ class ShapEPipeline(DiffusionPipeline):
187
187
  num_images_per_prompt: int = 1,
188
188
  num_inference_steps: int = 25,
189
189
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
190
- latents: Optional[torch.FloatTensor] = None,
190
+ latents: Optional[torch.Tensor] = None,
191
191
  guidance_scale: float = 4.0,
192
192
  frame_size: int = 64,
193
193
  output_type: Optional[str] = "pil", # pil, np, latent, mesh
@@ -207,7 +207,7 @@ class ShapEPipeline(DiffusionPipeline):
207
207
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
208
208
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
209
209
  generation deterministic.
210
- latents (`torch.FloatTensor`, *optional*):
210
+ latents (`torch.Tensor`, *optional*):
211
211
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
212
212
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
213
213
  tensor is generated by sampling using the supplied random `generator`.
@@ -70,7 +70,7 @@ class ShapEPipelineOutput(BaseOutput):
70
70
  Output class for [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`].
71
71
 
72
72
  Args:
73
- images (`torch.FloatTensor`)
73
+ images (`torch.Tensor`)
74
74
  A list of images for 3D rendering.
75
75
  """
76
76
 
@@ -86,7 +86,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
86
86
 
87
87
  Args:
88
88
  prior ([`PriorTransformer`]):
89
- The canonincal unCLIP prior to approximate the image embedding from the text embedding.
89
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
90
90
  image_encoder ([`~transformers.CLIPVisionModel`]):
91
91
  Frozen image-encoder.
92
92
  image_processor ([`~transformers.CLIPImageProcessor`]):
@@ -169,7 +169,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
169
169
  num_images_per_prompt: int = 1,
170
170
  num_inference_steps: int = 25,
171
171
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
172
- latents: Optional[torch.FloatTensor] = None,
172
+ latents: Optional[torch.Tensor] = None,
173
173
  guidance_scale: float = 4.0,
174
174
  frame_size: int = 64,
175
175
  output_type: Optional[str] = "pil", # pil, np, latent, mesh
@@ -179,7 +179,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
179
179
  The call function to the pipeline for generation.
180
180
 
181
181
  Args:
182
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
182
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
183
183
  `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
184
184
  latents as image, but if passing latents directly it is not encoded again.
185
185
  num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -190,7 +190,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
190
190
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
191
191
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
192
192
  generation deterministic.
193
- latents (`torch.FloatTensor`, *optional*):
193
+ latents (`torch.Tensor`, *optional*):
194
194
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
195
195
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
196
196
  tensor is generated by sampling using the supplied random `generator`.
@@ -239,15 +239,15 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
239
239
 
240
240
  num_embeddings = self.prior.config.num_embeddings
241
241
  embedding_dim = self.prior.config.embedding_dim
242
-
243
- latents = self.prepare_latents(
244
- (batch_size, num_embeddings * embedding_dim),
245
- image_embeds.dtype,
246
- device,
247
- generator,
248
- latents,
249
- self.scheduler,
250
- )
242
+ if latents is None:
243
+ latents = self.prepare_latents(
244
+ (batch_size, num_embeddings * embedding_dim),
245
+ image_embeds.dtype,
246
+ device,
247
+ generator,
248
+ latents,
249
+ self.scheduler,
250
+ )
251
251
 
252
252
  # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
253
253
  latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
@@ -844,7 +844,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
844
844
  transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1],
845
845
  math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
846
846
 
847
- args:
847
+ Args:
848
848
  rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples:
849
849
  number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including
850
850
 
@@ -0,0 +1,50 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_torch_available,
9
+ is_transformers_available,
10
+ is_transformers_version,
11
+ )
12
+
13
+
14
+ _dummy_objects = {}
15
+ _import_structure = {}
16
+
17
+ try:
18
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
19
+ raise OptionalDependencyNotAvailable()
20
+ except OptionalDependencyNotAvailable:
21
+ from ...utils import dummy_torch_and_transformers_objects
22
+
23
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
+ else:
25
+ _import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"]
26
+ _import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"]
27
+
28
+
29
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
30
+ try:
31
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
32
+ raise OptionalDependencyNotAvailable()
33
+ except OptionalDependencyNotAvailable:
34
+ from ...utils.dummy_torch_and_transformers_objects import *
35
+
36
+ else:
37
+ from .modeling_stable_audio import StableAudioProjectionModel
38
+ from .pipeline_stable_audio import StableAudioPipeline
39
+
40
+ else:
41
+ import sys
42
+
43
+ sys.modules[__name__] = _LazyModule(
44
+ __name__,
45
+ globals()["__file__"],
46
+ _import_structure,
47
+ module_spec=__spec__,
48
+ )
49
+ for name, value in _dummy_objects.items():
50
+ setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,158 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from math import pi
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from ...configuration_utils import ConfigMixin, register_to_config
24
+ from ...models.modeling_utils import ModelMixin
25
+ from ...utils import BaseOutput, logging
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class StableAudioPositionalEmbedding(nn.Module):
32
+ """Used for continuous time"""
33
+
34
+ def __init__(self, dim: int):
35
+ super().__init__()
36
+ assert (dim % 2) == 0
37
+ half_dim = dim // 2
38
+ self.weights = nn.Parameter(torch.randn(half_dim))
39
+
40
+ def forward(self, times: torch.Tensor) -> torch.Tensor:
41
+ times = times[..., None]
42
+ freqs = times * self.weights[None] * 2 * pi
43
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
44
+ fouriered = torch.cat((times, fouriered), dim=-1)
45
+ return fouriered
46
+
47
+
48
+ @dataclass
49
+ class StableAudioProjectionModelOutput(BaseOutput):
50
+ """
51
+ Class for StableAudio projection layer's outputs.
52
+ Args:
53
+ text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
54
+ Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder.
55
+ seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
56
+ Sequence of hidden-states obtained by linearly projecting the audio start hidden states.
57
+ seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
58
+ Sequence of hidden-states obtained by linearly projecting the audio end hidden states.
59
+ """
60
+
61
+ text_hidden_states: Optional[torch.Tensor] = None
62
+ seconds_start_hidden_states: Optional[torch.Tensor] = None
63
+ seconds_end_hidden_states: Optional[torch.Tensor] = None
64
+
65
+
66
+ class StableAudioNumberConditioner(nn.Module):
67
+ """
68
+ A simple linear projection model to map numbers to a latent space.
69
+
70
+ Args:
71
+ number_embedding_dim (`int`):
72
+ Dimensionality of the number embeddings.
73
+ min_value (`int`):
74
+ The minimum value of the seconds number conditioning modules.
75
+ max_value (`int`):
76
+ The maximum value of the seconds number conditioning modules.
77
+ internal_dim (`int`):
78
+ Dimensionality of the intermediate number hidden states.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ number_embedding_dim,
84
+ min_value,
85
+ max_value,
86
+ internal_dim: Optional[int] = 256,
87
+ ):
88
+ super().__init__()
89
+ self.time_positional_embedding = nn.Sequential(
90
+ StableAudioPositionalEmbedding(internal_dim),
91
+ nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
92
+ )
93
+
94
+ self.number_embedding_dim = number_embedding_dim
95
+ self.min_value = min_value
96
+ self.max_value = max_value
97
+
98
+ def forward(
99
+ self,
100
+ floats: torch.Tensor,
101
+ ):
102
+ floats = floats.clamp(self.min_value, self.max_value)
103
+
104
+ normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
105
+
106
+ # Cast floats to same type as embedder
107
+ embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
108
+ normalized_floats = normalized_floats.to(embedder_dtype)
109
+
110
+ embedding = self.time_positional_embedding(normalized_floats)
111
+ float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
112
+
113
+ return float_embeds
114
+
115
+
116
+ class StableAudioProjectionModel(ModelMixin, ConfigMixin):
117
+ """
118
+ A simple linear projection model to map the conditioning values to a shared latent space.
119
+
120
+ Args:
121
+ text_encoder_dim (`int`):
122
+ Dimensionality of the text embeddings from the text encoder (T5).
123
+ conditioning_dim (`int`):
124
+ Dimensionality of the output conditioning tensors.
125
+ min_value (`int`):
126
+ The minimum value of the seconds number conditioning modules.
127
+ max_value (`int`):
128
+ The maximum value of the seconds number conditioning modules.
129
+ """
130
+
131
+ @register_to_config
132
+ def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
133
+ super().__init__()
134
+ self.text_projection = (
135
+ nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
136
+ )
137
+ self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
138
+ self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
139
+
140
+ def forward(
141
+ self,
142
+ text_hidden_states: Optional[torch.Tensor] = None,
143
+ start_seconds: Optional[torch.Tensor] = None,
144
+ end_seconds: Optional[torch.Tensor] = None,
145
+ ):
146
+ text_hidden_states = (
147
+ text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
148
+ )
149
+ seconds_start_hidden_states = (
150
+ start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds)
151
+ )
152
+ seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds)
153
+
154
+ return StableAudioProjectionModelOutput(
155
+ text_hidden_states=text_hidden_states,
156
+ seconds_start_hidden_states=seconds_start_hidden_states,
157
+ seconds_end_hidden_states=seconds_end_hidden_states,
158
+ )