diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270) hide show
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -66,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
66
66
  if add_downsample:
67
67
  self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
68
68
 
69
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
69
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
70
70
  output_states = ()
71
71
 
72
72
  hidden_states = self.resnets[0](hidden_states, temb)
@@ -128,10 +128,10 @@ class UpResnetBlock1D(nn.Module):
128
128
 
129
129
  def forward(
130
130
  self,
131
- hidden_states: torch.FloatTensor,
132
- res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
133
- temb: Optional[torch.FloatTensor] = None,
134
- ) -> torch.FloatTensor:
131
+ hidden_states: torch.Tensor,
132
+ res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None,
133
+ temb: Optional[torch.Tensor] = None,
134
+ ) -> torch.Tensor:
135
135
  if res_hidden_states_tuple is not None:
136
136
  res_hidden_states = res_hidden_states_tuple[-1]
137
137
  hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
@@ -161,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
161
161
  self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
162
162
  self.down2 = Downsample1D(out_channels // 4, use_conv=True)
163
163
 
164
- def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
164
+ def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
165
165
  x = self.res1(x, temb)
166
166
  x = self.down1(x)
167
167
  x = self.res2(x, temb)
@@ -209,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
209
209
  if self.upsample and self.downsample:
210
210
  raise ValueError("Block cannot downsample and upsample")
211
211
 
212
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
212
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
213
213
  hidden_states = self.resnets[0](hidden_states, temb)
214
214
  for resnet in self.resnets[1:]:
215
215
  hidden_states = resnet(hidden_states, temb)
@@ -230,7 +230,7 @@ class OutConv1DBlock(nn.Module):
230
230
  self.final_conv1d_act = get_activation(act_fn)
231
231
  self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
232
232
 
233
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
233
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
234
234
  hidden_states = self.final_conv1d_1(hidden_states)
235
235
  hidden_states = rearrange_dims(hidden_states)
236
236
  hidden_states = self.final_conv1d_gn(hidden_states)
@@ -251,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
251
251
  ]
252
252
  )
253
253
 
254
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
254
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
255
255
  hidden_states = hidden_states.view(hidden_states.shape[0], -1)
256
256
  hidden_states = torch.cat((hidden_states, temb), dim=-1)
257
257
  for layer in self.final_block:
@@ -288,7 +288,7 @@ class Downsample1d(nn.Module):
288
288
  self.pad = kernel_1d.shape[0] // 2 - 1
289
289
  self.register_buffer("kernel", kernel_1d)
290
290
 
291
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
291
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
292
292
  hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
293
293
  weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
294
294
  indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -305,7 +305,7 @@ class Upsample1d(nn.Module):
305
305
  self.pad = kernel_1d.shape[0] // 2 - 1
306
306
  self.register_buffer("kernel", kernel_1d)
307
307
 
308
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
308
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
309
309
  hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
310
310
  weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
311
311
  indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -335,7 +335,7 @@ class SelfAttention1d(nn.Module):
335
335
  new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
336
336
  return new_projection
337
337
 
338
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
338
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
339
339
  residual = hidden_states
340
340
  batch, channel_dim, seq = hidden_states.shape
341
341
 
@@ -390,7 +390,7 @@ class ResConvBlock(nn.Module):
390
390
  self.group_norm_2 = nn.GroupNorm(1, out_channels)
391
391
  self.gelu_2 = nn.GELU()
392
392
 
393
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
393
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
394
394
  residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
395
395
 
396
396
  hidden_states = self.conv_1(hidden_states)
@@ -435,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
435
435
  self.attentions = nn.ModuleList(attentions)
436
436
  self.resnets = nn.ModuleList(resnets)
437
437
 
438
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
438
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
439
439
  hidden_states = self.down(hidden_states)
440
440
  for attn, resnet in zip(self.attentions, self.resnets):
441
441
  hidden_states = resnet(hidden_states)
@@ -466,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
466
466
  self.attentions = nn.ModuleList(attentions)
467
467
  self.resnets = nn.ModuleList(resnets)
468
468
 
469
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
469
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
470
470
  hidden_states = self.down(hidden_states)
471
471
 
472
472
  for resnet, attn in zip(self.resnets, self.attentions):
@@ -490,7 +490,7 @@ class DownBlock1D(nn.Module):
490
490
 
491
491
  self.resnets = nn.ModuleList(resnets)
492
492
 
493
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
493
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
494
494
  hidden_states = self.down(hidden_states)
495
495
 
496
496
  for resnet in self.resnets:
@@ -512,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
512
512
 
513
513
  self.resnets = nn.ModuleList(resnets)
514
514
 
515
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
515
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
516
516
  hidden_states = torch.cat([hidden_states, temb], dim=1)
517
517
  for resnet in self.resnets:
518
518
  hidden_states = resnet(hidden_states)
@@ -542,10 +542,10 @@ class AttnUpBlock1D(nn.Module):
542
542
 
543
543
  def forward(
544
544
  self,
545
- hidden_states: torch.FloatTensor,
546
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
547
- temb: Optional[torch.FloatTensor] = None,
548
- ) -> torch.FloatTensor:
545
+ hidden_states: torch.Tensor,
546
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
547
+ temb: Optional[torch.Tensor] = None,
548
+ ) -> torch.Tensor:
549
549
  res_hidden_states = res_hidden_states_tuple[-1]
550
550
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
551
551
 
@@ -574,10 +574,10 @@ class UpBlock1D(nn.Module):
574
574
 
575
575
  def forward(
576
576
  self,
577
- hidden_states: torch.FloatTensor,
578
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
579
- temb: Optional[torch.FloatTensor] = None,
580
- ) -> torch.FloatTensor:
577
+ hidden_states: torch.Tensor,
578
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
579
+ temb: Optional[torch.Tensor] = None,
580
+ ) -> torch.Tensor:
581
581
  res_hidden_states = res_hidden_states_tuple[-1]
582
582
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
583
583
 
@@ -604,10 +604,10 @@ class UpBlock1DNoSkip(nn.Module):
604
604
 
605
605
  def forward(
606
606
  self,
607
- hidden_states: torch.FloatTensor,
608
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
609
- temb: Optional[torch.FloatTensor] = None,
610
- ) -> torch.FloatTensor:
607
+ hidden_states: torch.Tensor,
608
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
609
+ temb: Optional[torch.Tensor] = None,
610
+ ) -> torch.Tensor:
611
611
  res_hidden_states = res_hidden_states_tuple[-1]
612
612
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
613
613
 
@@ -30,11 +30,11 @@ class UNet2DOutput(BaseOutput):
30
30
  The output of [`UNet2DModel`].
31
31
 
32
32
  Args:
33
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
33
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
34
34
  The hidden states output from the last layer of the model.
35
35
  """
36
36
 
37
- sample: torch.FloatTensor
37
+ sample: torch.Tensor
38
38
 
39
39
 
40
40
  class UNet2DModel(ModelMixin, ConfigMixin):
@@ -242,7 +242,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
242
242
 
243
243
  def forward(
244
244
  self,
245
- sample: torch.FloatTensor,
245
+ sample: torch.Tensor,
246
246
  timestep: Union[torch.Tensor, float, int],
247
247
  class_labels: Optional[torch.Tensor] = None,
248
248
  return_dict: bool = True,
@@ -251,10 +251,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
251
251
  The [`UNet2DModel`] forward method.
252
252
 
253
253
  Args:
254
- sample (`torch.FloatTensor`):
254
+ sample (`torch.Tensor`):
255
255
  The noisy input tensor with the following shape `(batch, channel, height, width)`.
256
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
257
- class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
256
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
257
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
258
258
  Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
259
259
  return_dict (`bool`, *optional*, defaults to `True`):
260
260
  Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.