diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445) hide show
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,183 @@
1
+ import os
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
8
+ from ...models.modeling_utils import ModelMixin
9
+ from ...utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class MultiControlNetModel(ModelMixin):
16
+ r"""
17
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
18
+
19
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
20
+ compatible with `ControlNetModel`.
21
+
22
+ Args:
23
+ controlnets (`List[ControlNetModel]`):
24
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
25
+ `ControlNetModel` as a list.
26
+ """
27
+
28
+ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
29
+ super().__init__()
30
+ self.nets = nn.ModuleList(controlnets)
31
+
32
+ def forward(
33
+ self,
34
+ sample: torch.Tensor,
35
+ timestep: Union[torch.Tensor, float, int],
36
+ encoder_hidden_states: torch.Tensor,
37
+ controlnet_cond: List[torch.tensor],
38
+ conditioning_scale: List[float],
39
+ class_labels: Optional[torch.Tensor] = None,
40
+ timestep_cond: Optional[torch.Tensor] = None,
41
+ attention_mask: Optional[torch.Tensor] = None,
42
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
43
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
44
+ guess_mode: bool = False,
45
+ return_dict: bool = True,
46
+ ) -> Union[ControlNetOutput, Tuple]:
47
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
48
+ down_samples, mid_sample = controlnet(
49
+ sample=sample,
50
+ timestep=timestep,
51
+ encoder_hidden_states=encoder_hidden_states,
52
+ controlnet_cond=image,
53
+ conditioning_scale=scale,
54
+ class_labels=class_labels,
55
+ timestep_cond=timestep_cond,
56
+ attention_mask=attention_mask,
57
+ added_cond_kwargs=added_cond_kwargs,
58
+ cross_attention_kwargs=cross_attention_kwargs,
59
+ guess_mode=guess_mode,
60
+ return_dict=return_dict,
61
+ )
62
+
63
+ # merge samples
64
+ if i == 0:
65
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
66
+ else:
67
+ down_block_res_samples = [
68
+ samples_prev + samples_curr
69
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
70
+ ]
71
+ mid_block_res_sample += mid_sample
72
+
73
+ return down_block_res_samples, mid_block_res_sample
74
+
75
+ def save_pretrained(
76
+ self,
77
+ save_directory: Union[str, os.PathLike],
78
+ is_main_process: bool = True,
79
+ save_function: Callable = None,
80
+ safe_serialization: bool = True,
81
+ variant: Optional[str] = None,
82
+ ):
83
+ """
84
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
85
+ `[`~models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained`]` class method.
86
+
87
+ Arguments:
88
+ save_directory (`str` or `os.PathLike`):
89
+ Directory to which to save. Will be created if it doesn't exist.
90
+ is_main_process (`bool`, *optional*, defaults to `True`):
91
+ Whether the process calling this is the main process or not. Useful when in distributed training like
92
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
93
+ the main process to avoid race conditions.
94
+ save_function (`Callable`):
95
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
96
+ need to replace `torch.save` by another method. Can be configured with the environment variable
97
+ `DIFFUSERS_SAVE_MODE`.
98
+ safe_serialization (`bool`, *optional*, defaults to `True`):
99
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
100
+ variant (`str`, *optional*):
101
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
102
+ """
103
+ for idx, controlnet in enumerate(self.nets):
104
+ suffix = "" if idx == 0 else f"_{idx}"
105
+ controlnet.save_pretrained(
106
+ save_directory + suffix,
107
+ is_main_process=is_main_process,
108
+ save_function=save_function,
109
+ safe_serialization=safe_serialization,
110
+ variant=variant,
111
+ )
112
+
113
+ @classmethod
114
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
115
+ r"""
116
+ Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
117
+
118
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
119
+ the model, you should first set it back in training mode with `model.train()`.
120
+
121
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
122
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
123
+ task.
124
+
125
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
126
+ weights are discarded.
127
+
128
+ Parameters:
129
+ pretrained_model_path (`os.PathLike`):
130
+ A path to a *directory* containing model weights saved using
131
+ [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
132
+ `./my_model_directory/controlnet`.
133
+ torch_dtype (`str` or `torch.dtype`, *optional*):
134
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
135
+ will be automatically derived from the model's weights.
136
+ output_loading_info(`bool`, *optional*, defaults to `False`):
137
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
138
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
139
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
140
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
141
+ same device.
142
+
143
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
144
+ more information about each option see [designing a device
145
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
146
+ max_memory (`Dict`, *optional*):
147
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
148
+ GPU and the available CPU RAM if unset.
149
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
150
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
151
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
152
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
153
+ setting this argument to `True` will raise an error.
154
+ variant (`str`, *optional*):
155
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
156
+ ignored when using `from_flax`.
157
+ use_safetensors (`bool`, *optional*, defaults to `None`):
158
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
159
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
160
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
161
+ """
162
+ idx = 0
163
+ controlnets = []
164
+
165
+ # load controlnet and append to list until no controlnet directory exists anymore
166
+ # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
167
+ # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
168
+ model_path_to_load = pretrained_model_path
169
+ while os.path.isdir(model_path_to_load):
170
+ controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
171
+ controlnets.append(controlnet)
172
+
173
+ idx += 1
174
+ model_path_to_load = pretrained_model_path + f"_{idx}"
175
+
176
+ logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
177
+
178
+ if len(controlnets) == 0:
179
+ raise ValueError(
180
+ f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
181
+ )
182
+
183
+ return cls(controlnets)
@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
102
102
  self.padding = padding
103
103
  stride = 2
104
104
  self.name = name
105
- conv_cls = nn.Conv2d
106
105
 
107
106
  if norm_type == "ln_norm":
108
107
  self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
114
113
  raise ValueError(f"unknown norm_type: {norm_type}")
115
114
 
116
115
  if use_conv:
117
- conv = conv_cls(
116
+ conv = nn.Conv2d(
118
117
  self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
119
118
  )
120
119
  else:
@@ -130,7 +129,7 @@ class Downsample2D(nn.Module):
130
129
  else:
131
130
  self.conv = conv
132
131
 
133
- def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
132
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
134
133
  if len(args) > 0 or kwargs.get("scale", None) is not None:
135
134
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
136
135
  deprecate("scale", "1.0.0", deprecation_message)
@@ -181,24 +180,24 @@ class FirDownsample2D(nn.Module):
181
180
 
182
181
  def _downsample_2d(
183
182
  self,
184
- hidden_states: torch.FloatTensor,
185
- weight: Optional[torch.FloatTensor] = None,
186
- kernel: Optional[torch.FloatTensor] = None,
183
+ hidden_states: torch.Tensor,
184
+ weight: Optional[torch.Tensor] = None,
185
+ kernel: Optional[torch.Tensor] = None,
187
186
  factor: int = 2,
188
187
  gain: float = 1,
189
- ) -> torch.FloatTensor:
188
+ ) -> torch.Tensor:
190
189
  """Fused `Conv2d()` followed by `downsample_2d()`.
191
190
  Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
192
191
  efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
193
192
  arbitrary order.
194
193
 
195
194
  Args:
196
- hidden_states (`torch.FloatTensor`):
195
+ hidden_states (`torch.Tensor`):
197
196
  Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
198
- weight (`torch.FloatTensor`, *optional*):
197
+ weight (`torch.Tensor`, *optional*):
199
198
  Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
200
199
  performed by `inChannels = x.shape[0] // numGroups`.
201
- kernel (`torch.FloatTensor`, *optional*):
200
+ kernel (`torch.Tensor`, *optional*):
202
201
  FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
203
202
  corresponds to average pooling.
204
203
  factor (`int`, *optional*, default to `2`):
@@ -207,7 +206,7 @@ class FirDownsample2D(nn.Module):
207
206
  Scaling factor for signal magnitude.
208
207
 
209
208
  Returns:
210
- output (`torch.FloatTensor`):
209
+ output (`torch.Tensor`):
211
210
  Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
212
211
  datatype as `x`.
213
212
  """
@@ -245,7 +244,7 @@ class FirDownsample2D(nn.Module):
245
244
 
246
245
  return output
247
246
 
248
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
247
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
249
248
  if self.use_conv:
250
249
  downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
251
250
  hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -286,12 +285,80 @@ class KDownsample2D(nn.Module):
286
285
  return F.conv2d(inputs, weight, stride=2)
287
286
 
288
287
 
288
+ class CogVideoXDownsample3D(nn.Module):
289
+ # Todo: Wait for paper relase.
290
+ r"""
291
+ A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
292
+
293
+ Args:
294
+ in_channels (`int`):
295
+ Number of channels in the input image.
296
+ out_channels (`int`):
297
+ Number of channels produced by the convolution.
298
+ kernel_size (`int`, defaults to `3`):
299
+ Size of the convolving kernel.
300
+ stride (`int`, defaults to `2`):
301
+ Stride of the convolution.
302
+ padding (`int`, defaults to `0`):
303
+ Padding added to all four sides of the input.
304
+ compress_time (`bool`, defaults to `False`):
305
+ Whether or not to compress the time dimension.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ in_channels: int,
311
+ out_channels: int,
312
+ kernel_size: int = 3,
313
+ stride: int = 2,
314
+ padding: int = 0,
315
+ compress_time: bool = False,
316
+ ):
317
+ super().__init__()
318
+
319
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
320
+ self.compress_time = compress_time
321
+
322
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
323
+ if self.compress_time:
324
+ batch_size, channels, frames, height, width = x.shape
325
+
326
+ # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
327
+ x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
328
+
329
+ if x.shape[-1] % 2 == 1:
330
+ x_first, x_rest = x[..., 0], x[..., 1:]
331
+ if x_rest.shape[-1] > 0:
332
+ # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
333
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
334
+
335
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
336
+ # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
337
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
338
+ else:
339
+ # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
340
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
341
+ # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
342
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
343
+
344
+ # Pad the tensor
345
+ pad = (0, 1, 0, 1)
346
+ x = F.pad(x, pad, mode="constant", value=0)
347
+ batch_size, channels, frames, height, width = x.shape
348
+ # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
349
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
350
+ x = self.conv(x)
351
+ # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
352
+ x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
353
+ return x
354
+
355
+
289
356
  def downsample_2d(
290
- hidden_states: torch.FloatTensor,
291
- kernel: Optional[torch.FloatTensor] = None,
357
+ hidden_states: torch.Tensor,
358
+ kernel: Optional[torch.Tensor] = None,
292
359
  factor: int = 2,
293
360
  gain: float = 1,
294
- ) -> torch.FloatTensor:
361
+ ) -> torch.Tensor:
295
362
  r"""Downsample2D a batch of 2D images with the given filter.
296
363
  Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
297
364
  given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -299,9 +366,9 @@ def downsample_2d(
299
366
  shape is a multiple of the downsampling factor.
300
367
 
301
368
  Args:
302
- hidden_states (`torch.FloatTensor`)
369
+ hidden_states (`torch.Tensor`)
303
370
  Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
304
- kernel (`torch.FloatTensor`, *optional*):
371
+ kernel (`torch.Tensor`, *optional*):
305
372
  FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
306
373
  corresponds to average pooling.
307
374
  factor (`int`, *optional*, default to `2`):
@@ -310,7 +377,7 @@ def downsample_2d(
310
377
  Scaling factor for signal magnitude.
311
378
 
312
379
  Returns:
313
- output (`torch.FloatTensor`):
380
+ output (`torch.Tensor`):
314
381
  Tensor of the shape `[N, C, H // factor, W // factor]`
315
382
  """
316
383