diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput):
31
31
  The output of [`TransformerTemporalModel`].
32
32
 
33
33
  Args:
34
- sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
34
+ sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
35
35
  The hidden states output conditioned on `encoder_hidden_states` input.
36
36
  """
37
37
 
38
- sample: torch.FloatTensor
38
+ sample: torch.Tensor
39
39
 
40
40
 
41
41
  class TransformerTemporalModel(ModelMixin, ConfigMixin):
@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
120
120
 
121
121
  def forward(
122
122
  self,
123
- hidden_states: torch.FloatTensor,
123
+ hidden_states: torch.Tensor,
124
124
  encoder_hidden_states: Optional[torch.LongTensor] = None,
125
125
  timestep: Optional[torch.LongTensor] = None,
126
126
  class_labels: torch.LongTensor = None,
@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
132
132
  The [`TransformerTemporal`] forward method.
133
133
 
134
134
  Args:
135
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
135
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
136
136
  Input hidden_states.
137
137
  encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
138
138
  Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module):
283
283
  ):
284
284
  """
285
285
  Args:
286
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
286
+ hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
287
287
  Input hidden_states.
288
288
  num_frames (`int`):
289
289
  The number of frames to be processed per batch. This is used to reshape the hidden states.
@@ -294,8 +294,8 @@ class TransformerSpatioTemporalModel(nn.Module):
294
294
  A tensor indicating whether the input contains only images. 1 indicates that the input contains only
295
295
  images, 0 indicates that the input contains video frames.
296
296
  return_dict (`bool`, *optional*, defaults to `True`):
297
- Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
298
- tuple.
297
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
298
+ plain tuple.
299
299
 
300
300
  Returns:
301
301
  [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module):
311
311
  time_context_first_timestep = time_context[None, :].reshape(
312
312
  batch_size, num_frames, -1, time_context.shape[-1]
313
313
  )[:, 0]
314
- time_context = time_context_first_timestep[None, :].broadcast_to(
315
- height * width, batch_size, 1, time_context.shape[-1]
314
+ time_context = time_context_first_timestep[:, None].broadcast_to(
315
+ batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
316
316
  )
317
- time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
317
+ time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
318
318
 
319
319
  residual = hidden_states
320
320
 
@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput):
31
31
  The output of [`UNet1DModel`].
32
32
 
33
33
  Args:
34
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
34
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`):
35
35
  The hidden states output from the last layer of the model.
36
36
  """
37
37
 
38
- sample: torch.FloatTensor
38
+ sample: torch.Tensor
39
39
 
40
40
 
41
41
  class UNet1DModel(ModelMixin, ConfigMixin):
@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
194
194
 
195
195
  def forward(
196
196
  self,
197
- sample: torch.FloatTensor,
197
+ sample: torch.Tensor,
198
198
  timestep: Union[torch.Tensor, float, int],
199
199
  return_dict: bool = True,
200
200
  ) -> Union[UNet1DOutput, Tuple]:
@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
202
202
  The [`UNet1DModel`] forward method.
203
203
 
204
204
  Args:
205
- sample (`torch.FloatTensor`):
205
+ sample (`torch.Tensor`):
206
206
  The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
207
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
207
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
208
208
  return_dict (`bool`, *optional*, defaults to `True`):
209
209
  Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
210
210
 
@@ -66,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
66
66
  if add_downsample:
67
67
  self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
68
68
 
69
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
69
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
70
70
  output_states = ()
71
71
 
72
72
  hidden_states = self.resnets[0](hidden_states, temb)
@@ -128,10 +128,10 @@ class UpResnetBlock1D(nn.Module):
128
128
 
129
129
  def forward(
130
130
  self,
131
- hidden_states: torch.FloatTensor,
132
- res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
133
- temb: Optional[torch.FloatTensor] = None,
134
- ) -> torch.FloatTensor:
131
+ hidden_states: torch.Tensor,
132
+ res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None,
133
+ temb: Optional[torch.Tensor] = None,
134
+ ) -> torch.Tensor:
135
135
  if res_hidden_states_tuple is not None:
136
136
  res_hidden_states = res_hidden_states_tuple[-1]
137
137
  hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
@@ -161,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
161
161
  self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
162
162
  self.down2 = Downsample1D(out_channels // 4, use_conv=True)
163
163
 
164
- def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
164
+ def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
165
165
  x = self.res1(x, temb)
166
166
  x = self.down1(x)
167
167
  x = self.res2(x, temb)
@@ -209,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
209
209
  if self.upsample and self.downsample:
210
210
  raise ValueError("Block cannot downsample and upsample")
211
211
 
212
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
212
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
213
213
  hidden_states = self.resnets[0](hidden_states, temb)
214
214
  for resnet in self.resnets[1:]:
215
215
  hidden_states = resnet(hidden_states, temb)
@@ -230,7 +230,7 @@ class OutConv1DBlock(nn.Module):
230
230
  self.final_conv1d_act = get_activation(act_fn)
231
231
  self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
232
232
 
233
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
233
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
234
234
  hidden_states = self.final_conv1d_1(hidden_states)
235
235
  hidden_states = rearrange_dims(hidden_states)
236
236
  hidden_states = self.final_conv1d_gn(hidden_states)
@@ -251,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
251
251
  ]
252
252
  )
253
253
 
254
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
254
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
255
255
  hidden_states = hidden_states.view(hidden_states.shape[0], -1)
256
256
  hidden_states = torch.cat((hidden_states, temb), dim=-1)
257
257
  for layer in self.final_block:
@@ -288,7 +288,7 @@ class Downsample1d(nn.Module):
288
288
  self.pad = kernel_1d.shape[0] // 2 - 1
289
289
  self.register_buffer("kernel", kernel_1d)
290
290
 
291
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
291
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
292
292
  hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
293
293
  weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
294
294
  indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -305,7 +305,7 @@ class Upsample1d(nn.Module):
305
305
  self.pad = kernel_1d.shape[0] // 2 - 1
306
306
  self.register_buffer("kernel", kernel_1d)
307
307
 
308
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
308
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
309
309
  hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
310
310
  weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
311
311
  indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -335,7 +335,7 @@ class SelfAttention1d(nn.Module):
335
335
  new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
336
336
  return new_projection
337
337
 
338
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
338
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
339
339
  residual = hidden_states
340
340
  batch, channel_dim, seq = hidden_states.shape
341
341
 
@@ -390,7 +390,7 @@ class ResConvBlock(nn.Module):
390
390
  self.group_norm_2 = nn.GroupNorm(1, out_channels)
391
391
  self.gelu_2 = nn.GELU()
392
392
 
393
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
393
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
394
394
  residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
395
395
 
396
396
  hidden_states = self.conv_1(hidden_states)
@@ -435,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
435
435
  self.attentions = nn.ModuleList(attentions)
436
436
  self.resnets = nn.ModuleList(resnets)
437
437
 
438
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
438
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
439
439
  hidden_states = self.down(hidden_states)
440
440
  for attn, resnet in zip(self.attentions, self.resnets):
441
441
  hidden_states = resnet(hidden_states)
@@ -466,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
466
466
  self.attentions = nn.ModuleList(attentions)
467
467
  self.resnets = nn.ModuleList(resnets)
468
468
 
469
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
469
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
470
470
  hidden_states = self.down(hidden_states)
471
471
 
472
472
  for resnet, attn in zip(self.resnets, self.attentions):
@@ -490,7 +490,7 @@ class DownBlock1D(nn.Module):
490
490
 
491
491
  self.resnets = nn.ModuleList(resnets)
492
492
 
493
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
493
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
494
494
  hidden_states = self.down(hidden_states)
495
495
 
496
496
  for resnet in self.resnets:
@@ -512,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
512
512
 
513
513
  self.resnets = nn.ModuleList(resnets)
514
514
 
515
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
515
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
516
516
  hidden_states = torch.cat([hidden_states, temb], dim=1)
517
517
  for resnet in self.resnets:
518
518
  hidden_states = resnet(hidden_states)
@@ -542,10 +542,10 @@ class AttnUpBlock1D(nn.Module):
542
542
 
543
543
  def forward(
544
544
  self,
545
- hidden_states: torch.FloatTensor,
546
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
547
- temb: Optional[torch.FloatTensor] = None,
548
- ) -> torch.FloatTensor:
545
+ hidden_states: torch.Tensor,
546
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
547
+ temb: Optional[torch.Tensor] = None,
548
+ ) -> torch.Tensor:
549
549
  res_hidden_states = res_hidden_states_tuple[-1]
550
550
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
551
551
 
@@ -574,10 +574,10 @@ class UpBlock1D(nn.Module):
574
574
 
575
575
  def forward(
576
576
  self,
577
- hidden_states: torch.FloatTensor,
578
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
579
- temb: Optional[torch.FloatTensor] = None,
580
- ) -> torch.FloatTensor:
577
+ hidden_states: torch.Tensor,
578
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
579
+ temb: Optional[torch.Tensor] = None,
580
+ ) -> torch.Tensor:
581
581
  res_hidden_states = res_hidden_states_tuple[-1]
582
582
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
583
583
 
@@ -604,10 +604,10 @@ class UpBlock1DNoSkip(nn.Module):
604
604
 
605
605
  def forward(
606
606
  self,
607
- hidden_states: torch.FloatTensor,
608
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
609
- temb: Optional[torch.FloatTensor] = None,
610
- ) -> torch.FloatTensor:
607
+ hidden_states: torch.Tensor,
608
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
609
+ temb: Optional[torch.Tensor] = None,
610
+ ) -> torch.Tensor:
611
611
  res_hidden_states = res_hidden_states_tuple[-1]
612
612
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
613
613
 
@@ -30,11 +30,11 @@ class UNet2DOutput(BaseOutput):
30
30
  The output of [`UNet2DModel`].
31
31
 
32
32
  Args:
33
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
33
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
34
34
  The hidden states output from the last layer of the model.
35
35
  """
36
36
 
37
- sample: torch.FloatTensor
37
+ sample: torch.Tensor
38
38
 
39
39
 
40
40
  class UNet2DModel(ModelMixin, ConfigMixin):
@@ -242,7 +242,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
242
242
 
243
243
  def forward(
244
244
  self,
245
- sample: torch.FloatTensor,
245
+ sample: torch.Tensor,
246
246
  timestep: Union[torch.Tensor, float, int],
247
247
  class_labels: Optional[torch.Tensor] = None,
248
248
  return_dict: bool = True,
@@ -251,10 +251,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
251
251
  The [`UNet2DModel`] forward method.
252
252
 
253
253
  Args:
254
- sample (`torch.FloatTensor`):
254
+ sample (`torch.Tensor`):
255
255
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
256
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
257
- class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
256
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
257
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
258
258
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
259
259
  return_dict (`bool`, *optional*, defaults to `True`):
260
260
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.