diffusers-0.27.0-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +50 -53
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.0.dist-info/RECORD +0 -399
  443. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
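Among the larger additions visible in this range are the new `diffusers/quantizers` package (bitsandbytes, GGUF, and torchao backends plus a shared `quantization_config`), a wave of video pipelines (Allegro, CogVideoX, HunyuanVideo, LTX, Mochi), and the PAG pipeline family. As rough orientation only, here is a minimal sketch of the quantized loading that the new package enables; the checkpoint ID, subfolder, and dtype choices are illustrative and not taken from this diff:

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Load a transformer in 4-bit via the new bitsandbytes quantizer backend.
# Model ID and dtypes below are illustrative assumptions.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

The hunks below are from `diffusers/models/unets/unet_3d_blocks.py` (entry 95 above), where the AnimateDiff motion blocks were moved out to `unet_motion_model` and replaced with deprecation shims.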
@@ -27,17 +27,58 @@ from ..resnet import (
     TemporalConvLayer,
     Upsample2D,
 )
-from ..transformers.dual_transformer_2d import DualTransformer2DModel
 from ..transformers.transformer_2d import Transformer2DModel
 from ..transformers.transformer_temporal import (
     TransformerSpatioTemporalModel,
     TransformerTemporalModel,
 )
+from .unet_motion_model import (
+    CrossAttnDownBlockMotion,
+    CrossAttnUpBlockMotion,
+    DownBlockMotion,
+    UNetMidBlockCrossAttnMotion,
+    UpBlockMotion,
+)
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class DownBlockMotion(DownBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
+        deprecate("DownBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
+        deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class UpBlockMotion(UpBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
+        deprecate("UpBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
+        deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
+        deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
 def get_down_block(
     down_block_type: str,
     num_layers: int,
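The five shim classes above keep the old import path working: each one subclasses its relocated counterpart from `unet_motion_model` and calls `deprecate()` when constructed. A minimal sketch of the resulting behavior, with illustrative constructor arguments:

from diffusers.models.unets import unet_3d_blocks, unet_motion_model

# Constructing through the old module still works, but the shim's __init__
# emits the deprecation warning above (removal scheduled for 1.0.0).
block = unet_3d_blocks.DownBlockMotion(in_channels=32, out_channels=32, temb_channels=128)

# The relocated class is the one the shim subclasses; this path is warning-free.
block = unet_motion_model.DownBlockMotion(in_channels=32, out_channels=32, temb_channels=128)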
@@ -58,12 +99,12 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    dropout: float = 0.0,
 ) -> Union[
     "DownBlock3D",
     "CrossAttnDownBlock3D",
-    "DownBlockMotion",
-    "CrossAttnDownBlockMotion",
     "DownBlockSpatioTemporal",
     "CrossAttnDownBlockSpatioTemporal",
 ]:
@@ -79,6 +120,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
@@ -100,44 +142,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
-        )
-    if down_block_type == "DownBlockMotion":
-        return DownBlockMotion(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            temb_channels=temb_channels,
-            add_downsample=add_downsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            downsample_padding=downsample_padding,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
-        )
-    elif down_block_type == "CrossAttnDownBlockMotion":
-        if cross_attention_dim is None:
-            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
-        return CrossAttnDownBlockMotion(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            temb_channels=temb_channels,
-            add_downsample=add_downsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            downsample_padding=downsample_padding,
-            cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads,
-            dual_cross_attention=dual_cross_attention,
-            use_linear_projection=use_linear_projection,
-            only_cross_attention=only_cross_attention,
-            upcast_attention=upcast_attention,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
+            dropout=dropout,
         )
     elif down_block_type == "DownBlockSpatioTemporal":
         # added for SDV
@@ -188,13 +193,12 @@ def get_up_block(
     temporal_num_attention_heads: int = 8,
     temporal_cross_attention_dim: Optional[int] = None,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     dropout: float = 0.0,
 ) -> Union[
     "UpBlock3D",
     "CrossAttnUpBlock3D",
-    "UpBlockMotion",
-    "CrossAttnUpBlockMotion",
     "UpBlockSpatioTemporal",
     "CrossAttnUpBlockSpatioTemporal",
 ]:
@@ -211,6 +215,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
@@ -233,46 +238,7 @@ def get_up_block(
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
-        )
-    if up_block_type == "UpBlockMotion":
-        return UpBlockMotion(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            prev_output_channel=prev_output_channel,
-            temb_channels=temb_channels,
-            add_upsample=add_upsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            resolution_idx=resolution_idx,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
-        )
-    elif up_block_type == "CrossAttnUpBlockMotion":
-        if cross_attention_dim is None:
-            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
-        return CrossAttnUpBlockMotion(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            prev_output_channel=prev_output_channel,
-            temb_channels=temb_channels,
-            add_upsample=add_upsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads,
-            dual_cross_attention=dual_cross_attention,
-            use_linear_projection=use_linear_projection,
-            only_cross_attention=only_cross_attention,
-            upcast_attention=upcast_attention,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            resolution_idx=resolution_idx,
-            temporal_num_attention_heads=temporal_num_attention_heads,
-            temporal_max_seq_length=temporal_max_seq_length,
+            dropout=dropout,
         )
     elif up_block_type == "UpBlockSpatioTemporal":
         # added for SDV
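With the motion branches deleted from both `get_down_block` and `get_up_block`, AnimateDiff-style blocks are now assembled by `UNetMotionModel` itself rather than dispatched through these helpers; the user-facing pipeline API is unchanged. A minimal sketch of that unaffected path, with illustrative model IDs:

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# The motion adapter supplies the temporal layers that UNetMotionModel wires
# into its own blocks, without going through get_down_block/get_up_block.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
)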
@@ -409,13 +375,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
         for attn, temp_attn, resnet, temp_conv in zip(
@@ -542,13 +508,13 @@ class CrossAttnDownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Dict[str, Any] = None,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         # TODO(Patrick, William) - attention mask is not used
         output_states = ()
 
@@ -649,10 +615,10 @@ class DownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -767,15 +733,15 @@ class CrossAttnUpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Dict[str, Any] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -889,12 +855,12 @@ class UpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -930,970 +896,137 @@ class UpBlock3D(nn.Module):
         return hidden_states
 
 
-class DownBlockMotion(nn.Module):
+class MidBlockTemporalDecoder(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
+        attention_head_dim: int = 512,
         num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        output_scale_factor: float = 1.0,
-        add_downsample: bool = True,
-        downsample_padding: int = 1,
-        temporal_num_attention_heads: int = 1,
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_max_seq_length: int = 32,
+        upcast_attention: bool = False,
     ):
         super().__init__()
-        resnets = []
-        motion_modules = []
 
+        resnets = []
+        attentions = []
         for i in range(num_layers):
-            in_channels = in_channels if i == 0 else out_channels
+            input_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
+                SpatioTemporalResBlock(
+                    in_channels=input_channels,
                     out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                    temb_channels=None,
+                    eps=1e-6,
+                    temporal_eps=1e-5,
+                    merge_factor=0.0,
+                    merge_strategy="learned",
+                    switch_spatial_to_temporal_mix=True,
                 )
             )
 
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_downsample:
-            self.downsamplers = nn.ModuleList(
-                [
-                    Downsample2D(
-                        out_channels,
-                        use_conv=True,
-                        out_channels=out_channels,
-                        padding=downsample_padding,
-                        name="op",
-                    )
-                ]
+        attentions.append(
+            Attention(
+                query_dim=in_channels,
+                heads=in_channels // attention_head_dim,
+                dim_head=attention_head_dim,
+                eps=1e-6,
+                upcast_attention=upcast_attention,
+                norm_num_groups=32,
+                bias=True,
+                residual_connection=True,
             )
-        else:
-            self.downsamplers = None
+        )
 
-        self.gradient_checkpointing = False
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        num_frames: int = 1,
-        *args,
-        **kwargs,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        output_states = ()
-
-        blocks = zip(self.resnets, self.motion_modules)
-        for resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
+        hidden_states: torch.Tensor,
+        image_only_indicator: torch.Tensor,
+    ):
+        hidden_states = self.resnets[0](
+            hidden_states,
+            image_only_indicator=image_only_indicator,
+        )
+        for resnet, attn in zip(self.resnets[1:], self.attentions):
+            hidden_states = attn(hidden_states)
+            hidden_states = resnet(
+                hidden_states,
+                image_only_indicator=image_only_indicator,
+            )
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
+        return hidden_states
 
-                    return custom_forward
 
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+class UpBlockTemporalDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
 
-            else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=1e-6,
+                    temporal_eps=1e-5,
+                    merge_factor=0.0,
+                    merge_strategy="learned",
+                    switch_spatial_to_temporal_mix=True,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
 
-            output_states = output_states + (hidden_states,)
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
 
-        if self.downsamplers is not None:
-            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        image_only_indicator: torch.Tensor,
+    ) -> torch.Tensor:
+        for resnet in self.resnets:
+            hidden_states = resnet(
+                hidden_states,
+                image_only_indicator=image_only_indicator,
+            )
 
-            output_states = output_states + (hidden_states,)
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
 
-        return hidden_states, output_states
+        return hidden_states
 
 
-class CrossAttnDownBlockMotion(nn.Module):
+class UNetMidBlockSpatioTemporal(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        out_channels: int,
         temb_channels: int,
-        dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         num_attention_heads: int = 1,
         cross_attention_dim: int = 1280,
-        output_scale_factor: float = 1.0,
-        downsample_padding: int = 1,
-        add_downsample: bool = True,
-        dual_cross_attention: bool = False,
-        use_linear_projection: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
-        attention_type: str = "default",
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_num_attention_heads: int = 8,
-        temporal_max_seq_length: int = 32,
     ):
         super().__init__()
-        resnets = []
-        attentions = []
-        motion_modules = []
 
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
-        for i in range(num_layers):
-            in_channels = in_channels if i == 0 else out_channels
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-            if not dual_cross_attention:
-                attentions.append(
-                    Transformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                        use_linear_projection=use_linear_projection,
-                        only_cross_attention=only_cross_attention,
-                        upcast_attention=upcast_attention,
-                        attention_type=attention_type,
-                    )
-                )
-            else:
-                attentions.append(
-                    DualTransformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=1,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                    )
-                )
-
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_downsample:
-            self.downsamplers = nn.ModuleList(
-                [
-                    Downsample2D(
-                        out_channels,
-                        use_conv=True,
-                        out_channels=out_channels,
-                        padding=downsample_padding,
-                        name="op",
-                    )
-                ]
-            )
-        else:
-            self.downsamplers = None
-
-        self.gradient_checkpointing = False
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames: int = 1,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ):
-        if cross_attention_kwargs is not None:
-            if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
-
-        output_states = ()
-
-        blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
-        for i, (resnet, attn, motion_module) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-            else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )[0]
-
-            # apply additional residuals to the output of the last pair of resnet and attention blocks
-            if i == len(blocks) - 1 and additional_residuals is not None:
-                hidden_states = hidden_states + additional_residuals
-
-            output_states = output_states + (hidden_states,)
-
-        if self.downsamplers is not None:
-            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
-
-            output_states = output_states + (hidden_states,)
-
-        return hidden_states, output_states
-
-
-class CrossAttnUpBlockMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        prev_output_channel: int,
-        temb_channels: int,
-        resolution_idx: Optional[int] = None,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
-        cross_attention_dim: int = 1280,
-        output_scale_factor: float = 1.0,
-        add_upsample: bool = True,
-        dual_cross_attention: bool = False,
-        use_linear_projection: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
-        attention_type: str = "default",
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_num_attention_heads: int = 8,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-        resnets = []
-        attentions = []
-        motion_modules = []
-
-        self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-
-        for i in range(num_layers):
-            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
-            resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=resnet_in_channels + res_skip_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-            if not dual_cross_attention:
-                attentions.append(
-                    Transformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                        use_linear_projection=use_linear_projection,
-                        only_cross_attention=only_cross_attention,
-                        upcast_attention=upcast_attention,
-                        attention_type=attention_type,
-                    )
-                )
-            else:
-                attentions.append(
-                    DualTransformer2DModel(
-                        num_attention_heads,
-                        out_channels // num_attention_heads,
-                        in_channels=out_channels,
-                        num_layers=1,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                    )
-                )
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
-        else:
-            self.upsamplers = None
-
-        self.gradient_checkpointing = False
-        self.resolution_idx = resolution_idx
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames: int = 1,
-    ) -> torch.FloatTensor:
-        if cross_attention_kwargs is not None:
-            if cross_attention_kwargs.get("scale", None) is not None:
1370
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1371
-
1372
- is_freeu_enabled = (
1373
- getattr(self, "s1", None)
1374
- and getattr(self, "s2", None)
1375
- and getattr(self, "b1", None)
1376
- and getattr(self, "b2", None)
1377
- )
1378
-
1379
- blocks = zip(self.resnets, self.attentions, self.motion_modules)
1380
- for resnet, attn, motion_module in blocks:
1381
- # pop res hidden states
1382
- res_hidden_states = res_hidden_states_tuple[-1]
1383
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1384
-
1385
- # FreeU: Only operate on the first two stages
1386
- if is_freeu_enabled:
1387
- hidden_states, res_hidden_states = apply_freeu(
1388
- self.resolution_idx,
1389
- hidden_states,
1390
- res_hidden_states,
1391
- s1=self.s1,
1392
- s2=self.s2,
1393
- b1=self.b1,
1394
- b2=self.b2,
1395
- )
1396
-
1397
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1398
-
1399
- if self.training and self.gradient_checkpointing:
1400
-
1401
- def create_custom_forward(module, return_dict=None):
1402
- def custom_forward(*inputs):
1403
- if return_dict is not None:
1404
- return module(*inputs, return_dict=return_dict)
1405
- else:
1406
- return module(*inputs)
1407
-
1408
- return custom_forward
1409
-
1410
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1411
- hidden_states = torch.utils.checkpoint.checkpoint(
1412
- create_custom_forward(resnet),
1413
- hidden_states,
1414
- temb,
1415
- **ckpt_kwargs,
1416
- )
1417
- hidden_states = attn(
1418
- hidden_states,
1419
- encoder_hidden_states=encoder_hidden_states,
1420
- cross_attention_kwargs=cross_attention_kwargs,
1421
- attention_mask=attention_mask,
1422
- encoder_attention_mask=encoder_attention_mask,
1423
- return_dict=False,
1424
- )[0]
1425
- else:
1426
- hidden_states = resnet(hidden_states, temb)
1427
- hidden_states = attn(
1428
- hidden_states,
1429
- encoder_hidden_states=encoder_hidden_states,
1430
- cross_attention_kwargs=cross_attention_kwargs,
1431
- attention_mask=attention_mask,
1432
- encoder_attention_mask=encoder_attention_mask,
1433
- return_dict=False,
1434
- )[0]
1435
- hidden_states = motion_module(
1436
- hidden_states,
1437
- num_frames=num_frames,
1438
- )[0]
1439
-
1440
- if self.upsamplers is not None:
1441
- for upsampler in self.upsamplers:
1442
- hidden_states = upsampler(hidden_states, upsample_size)
1443
-
1444
- return hidden_states
1445
-
1446
-
1447
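The `res_skip_channels` / `resnet_in_channels` bookkeeping above sets each resnet's input width after the skip concat. A quick illustration with hypothetical channel counts (320/640/1280 are placeholders, not taken from any checkpoint):

num_layers = 3
in_channels, out_channels, prev_output_channel = 320, 640, 1280

for i in range(num_layers):
    # Skips are popped deepest-first; the last one comes from the down block's input.
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    # The resnet consumes torch.cat([hidden_states, res_hidden_states], dim=1).
    print(f"layer {i}: {resnet_in_channels} + {res_skip_channels} -> {out_channels}")
# layer 0: 1280 + 640 -> 640
# layer 1: 640 + 640 -> 640
# layer 2: 640 + 320 -> 640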
-class UpBlockMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        prev_output_channel: int,
-        out_channels: int,
-        temb_channels: int,
-        resolution_idx: Optional[int] = None,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        output_scale_factor: float = 1.0,
-        add_upsample: bool = True,
-        temporal_norm_num_groups: int = 32,
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_num_attention_heads: int = 8,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-        resnets = []
-        motion_modules = []
-
-        for i in range(num_layers):
-            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
-            resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=resnet_in_channels + res_skip_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    in_channels=out_channels,
-                    norm_num_groups=temporal_norm_num_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    activation_fn="geglu",
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
-                )
-            )
-
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
-        else:
-            self.upsamplers = None
-
-        self.gradient_checkpointing = False
-        self.resolution_idx = resolution_idx
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        upsample_size=None,
-        num_frames: int = 1,
-        *args,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        is_freeu_enabled = (
-            getattr(self, "s1", None)
-            and getattr(self, "s2", None)
-            and getattr(self, "b1", None)
-            and getattr(self, "b2", None)
-        )
-
-        blocks = zip(self.resnets, self.motion_modules)
-
-        for resnet, motion_module in blocks:
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
-            # FreeU: Only operate on the first two stages
-            if is_freeu_enabled:
-                hidden_states, res_hidden_states = apply_freeu(
-                    self.resolution_idx,
-                    hidden_states,
-                    res_hidden_states,
-                    s1=self.s1,
-                    s2=self.s2,
-                    b1=self.b1,
-                    b2=self.b2,
-                )
-
-            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-
-            else:
-                hidden_states = resnet(hidden_states, temb)
-            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
-
-        if self.upsamplers is not None:
-            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
-
-        return hidden_states
-
-
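`is_freeu_enabled` is an attribute probe rather than a config flag: FreeU helpers such as the pipeline-level `enable_freeu` set `s1`/`s2`/`b1`/`b2` on each up block after construction. A toy sketch of the gate (the `_Block` class is illustrative, not diffusers code):

class _Block:
    resolution_idx = 0


def freeu_gate(block) -> bool:
    # Mirrors the getattr() chain in the forward passes above.
    return bool(
        getattr(block, "s1", None)
        and getattr(block, "s2", None)
        and getattr(block, "b1", None)
        and getattr(block, "b2", None)
    )


block = _Block()
print(freeu_gate(block))  # False: attributes were never set
block.s1, block.s2, block.b1, block.b2 = 0.9, 0.2, 1.2, 1.4
print(freeu_gate(block))  # True: apply_freeu() will now run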
-class UNetMidBlockCrossAttnMotion(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
-        output_scale_factor: float = 1.0,
-        cross_attention_dim: int = 1280,
-        dual_cross_attention: bool = False,
-        use_linear_projection: bool = False,
-        upcast_attention: bool = False,
-        attention_type: str = "default",
-        temporal_num_attention_heads: int = 1,
-        temporal_cross_attention_dim: Optional[int] = None,
-        temporal_max_seq_length: int = 32,
-    ):
-        super().__init__()
-
-        self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
-
-        # there is always at least one resnet
-        resnets = [
-            ResnetBlock2D(
-                in_channels=in_channels,
-                out_channels=in_channels,
-                temb_channels=temb_channels,
-                eps=resnet_eps,
-                groups=resnet_groups,
-                dropout=dropout,
-                time_embedding_norm=resnet_time_scale_shift,
-                non_linearity=resnet_act_fn,
-                output_scale_factor=output_scale_factor,
-                pre_norm=resnet_pre_norm,
-            )
-        ]
-        attentions = []
-        motion_modules = []
-
-        for _ in range(num_layers):
-            if not dual_cross_attention:
-                attentions.append(
-                    Transformer2DModel(
-                        num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                        use_linear_projection=use_linear_projection,
-                        upcast_attention=upcast_attention,
-                        attention_type=attention_type,
-                    )
-                )
-            else:
-                attentions.append(
-                    DualTransformer2DModel(
-                        num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
-                        num_layers=1,
-                        cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
-                    )
-                )
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-            motion_modules.append(
-                TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
-                    attention_head_dim=in_channels // temporal_num_attention_heads,
-                    in_channels=in_channels,
-                    norm_num_groups=resnet_groups,
-                    cross_attention_dim=temporal_cross_attention_dim,
-                    attention_bias=False,
-                    positional_embeddings="sinusoidal",
-                    num_positional_embeddings=temporal_max_seq_length,
-                    activation_fn="geglu",
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-        self.motion_modules = nn.ModuleList(motion_modules)
-
-        self.gradient_checkpointing = False
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        num_frames: int = 1,
-    ) -> torch.FloatTensor:
-        if cross_attention_kwargs is not None:
-            if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-
-        hidden_states = self.resnets[0](hidden_states, temb)
-
-        blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
-        for attn, resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )[0]
-                hidden_states = resnet(hidden_states, temb)
-
-        return hidden_states
-
-
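Note the mid block's ordering: the first resnet runs once up front, then each remaining triple runs attention -> motion -> resnet, which is why the loop zips `self.resnets[1:]`. A trace with placeholder names:

attentions = ["attn0", "attn1"]
resnets = ["res0", "res1", "res2"]  # one more resnet than attention/motion pairs
motion_modules = ["mot0", "mot1"]

trace = ["res0"]  # self.resnets[0](hidden_states, temb)
for attn, resnet, motion_module in zip(attentions, resnets[1:], motion_modules):
    trace += [attn, motion_module, resnet]
print(trace)  # ['res0', 'attn0', 'mot0', 'res1', 'attn1', 'mot1', 'res2']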
-class MidBlockTemporalDecoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        attention_head_dim: int = 512,
-        num_layers: int = 1,
-        upcast_attention: bool = False,
-    ):
-        super().__init__()
-
-        resnets = []
-        attentions = []
-        for i in range(num_layers):
-            input_channels = in_channels if i == 0 else out_channels
-            resnets.append(
-                SpatioTemporalResBlock(
-                    in_channels=input_channels,
-                    out_channels=out_channels,
-                    temb_channels=None,
-                    eps=1e-6,
-                    temporal_eps=1e-5,
-                    merge_factor=0.0,
-                    merge_strategy="learned",
-                    switch_spatial_to_temporal_mix=True,
-                )
-            )
-
-        attentions.append(
-            Attention(
-                query_dim=in_channels,
-                heads=in_channels // attention_head_dim,
-                dim_head=attention_head_dim,
-                eps=1e-6,
-                upcast_attention=upcast_attention,
-                norm_num_groups=32,
-                bias=True,
-                residual_connection=True,
-            )
-        )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
-    ):
-        hidden_states = self.resnets[0](
-            hidden_states,
-            image_only_indicator=image_only_indicator,
-        )
-        for resnet, attn in zip(self.resnets[1:], self.attentions):
-            hidden_states = attn(hidden_states)
-            hidden_states = resnet(
-                hidden_states,
-                image_only_indicator=image_only_indicator,
-            )
-
-        return hidden_states
-
-
-class UpBlockTemporalDecoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        num_layers: int = 1,
-        add_upsample: bool = True,
-    ):
-        super().__init__()
-        resnets = []
-        for i in range(num_layers):
-            input_channels = in_channels if i == 0 else out_channels
-
-            resnets.append(
-                SpatioTemporalResBlock(
-                    in_channels=input_channels,
-                    out_channels=out_channels,
-                    temb_channels=None,
-                    eps=1e-6,
-                    temporal_eps=1e-5,
-                    merge_factor=0.0,
-                    merge_strategy="learned",
-                    switch_spatial_to_temporal_mix=True,
-                )
-            )
-        self.resnets = nn.ModuleList(resnets)
-
-        if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
-        else:
-            self.upsamplers = None
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
-    ) -> torch.FloatTensor:
-        for resnet in self.resnets:
-            hidden_states = resnet(
-                hidden_states,
-                image_only_indicator=image_only_indicator,
-            )
-
-        if self.upsamplers is not None:
-            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
-
-        return hidden_states
-
-
-class UNetMidBlockSpatioTemporal(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        temb_channels: int,
-        num_layers: int = 1,
-        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
-        num_attention_heads: int = 1,
-        cross_attention_dim: int = 1280,
-    ):
-        super().__init__()
-
-        self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-
-        # support for variable transformer layers per block
-        if isinstance(transformer_layers_per_block, int):
-            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
 
         # there is always at least one resnet
         resnets = [
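The re-added `+` lines preserve the per-layer normalization: an `int` is broadcast to a list so later code can index a layer count per transformer block. For example:

num_layers = 3

for value in (2, (2, 4, 4)):
    transformer_layers_per_block = value
    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * num_layers
    print(list(transformer_layers_per_block))
# [2, 2, 2]
# [2, 4, 4]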
@@ -1933,11 +1066,11 @@ class UNetMidBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         hidden_states = self.resnets[0](
             hidden_states,
             temb,
@@ -1945,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
         )
 
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
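The recurring `self.training` -> `torch.is_grad_enabled()` change in these hunks swaps the checkpointing gate from the module's train/eval flag to the autograd state; checkpointing only pays off when a graph is actually being built. The difference in one snippet:

import torch
import torch.nn as nn

module = nn.Linear(4, 4).train()
with torch.no_grad():
    # Old gate (module.training) would still checkpoint here; the new gate skips it.
    print(module.training, torch.is_grad_enabled())  # True False

module.eval()
# New gate still allows checkpointing, e.g. when backpropagating through an eval-mode block.
print(module.training, torch.is_grad_enabled())  # False True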
@@ -2029,13 +1162,13 @@ class DownBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2139,16 +1272,16 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         blocks = list(zip(self.resnets, self.attentions))
         for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2238,11 +1371,12 @@ class UpBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+        upsample_size: Optional[int] = None,
+    ) -> torch.Tensor:
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -2250,7 +1384,7 @@ class UpBlockSpatioTemporal(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2282,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):
 
         if self.upsamplers is not None:
            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
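Threading `upsample_size` through the spatio-temporal up blocks matters for odd resolutions, where a stride-2 downsample followed by a plain 2x upsample no longer matches the skip tensor. A sketch with `F.interpolate` (diffusers' `Upsample2D` accepts the target size as its second, `output_size`, argument):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 33, 33)
down = F.avg_pool2d(x, kernel_size=2, stride=2)      # -> (1, 4, 16, 16)

up_default = F.interpolate(down, scale_factor=2.0)   # -> 32x32, mismatches the 33x33 skip
up_sized = F.interpolate(down, size=(33, 33))        # -> 33x33, concat-compatible
print(tuple(up_default.shape[-2:]), tuple(up_sized.shape[-2:]))  # (32, 32) (33, 33)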
 
@@ -2347,12 +1481,13 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+        upsample_size: Optional[int] = None,
+    ) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -2360,7 +1495,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2400,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
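The pervasive `torch.FloatTensor` -> `torch.Tensor` annotation change is cosmetic at runtime (annotations are not enforced), but `torch.Tensor` is the accurate type: these blocks routinely run in float16/bfloat16, which the legacy `FloatTensor` name does not cover:

import torch

t = torch.randn(2, 2, dtype=torch.float16)
print(isinstance(t, torch.Tensor))       # True
print(isinstance(t, torch.FloatTensor))  # False: FloatTensor means float32 on CPU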