diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
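
The hunks that follow excerpt diffusers/models/unets/unet_2d_blocks.py (entry 50 in the list above). Two changes recur throughout: type annotations move from `torch.FloatTensor` to plain `torch.Tensor`, and `UNetMidBlock2DCrossAttn` gains optional `out_channels` and `resnet_groups_out` arguments that fall back to `in_channels` and `resnet_groups` when omitted (the "depcrecated" typo in the `scale` warnings is also fixed). As a rough illustration of the new constructor arguments, here is a minimal sketch; the channel sizes, head count, and cross-attention dimension are illustrative placeholders, not values from this diff or any shipped model config.

# Minimal sketch (not part of this diff): exercising the optional arguments that
# 0.28.0 adds to UNetMidBlock2DCrossAttn. All sizes below are illustrative.
import torch
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn

mid_block = UNetMidBlock2DCrossAttn(
    in_channels=320,
    temb_channels=1280,
    out_channels=640,        # new in 0.28.0; defaults to in_channels when omitted
    resnet_groups=32,
    resnet_groups_out=32,    # new in 0.28.0; defaults to resnet_groups when omitted
    cross_attention_dim=768,
    num_attention_heads=8,
)

sample = torch.randn(1, 320, 8, 8)              # plain torch.Tensor matches the new type hints
temb = torch.randn(1, 1280)
encoder_hidden_states = torch.randn(1, 77, 768)
out = mid_block(sample, temb, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # expected torch.Size([1, 640, 8, 8]) for the sizes chosen above
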
@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module):
             ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.

     Returns:
-        `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
+        `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
         `out_channels`.
     """

@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module):
         )
         self.fuse = nn.ReLU()

-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fuse(self.conv(x) + self.skip(x))


@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module):
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

     Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
+        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
+        height, width)`.

     """

@@ -731,7 +731,7 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
@@ -746,6 +746,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +754,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -764,6 +766,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()

+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +778,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers

+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -794,11 +803,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
+                        norm_num_groups=resnet_groups_out,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         attention_type=attention_type,
@@ -808,8 +817,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     DualTransformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
@@ -817,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 )
             resnets.append(
                 ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
-                    groups=resnet_groups,
+                    groups=resnet_groups_out,
                     dropout=dropout,
                     time_embedding_norm=resnet_time_scale_shift,
                     non_linearity=resnet_act_fn,
@@ -837,16 +846,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -977,16 +986,16 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1109,14 +1118,14 @@ class AttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

@@ -1231,17 +1240,17 @@ class CrossAttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

@@ -1353,8 +1362,8 @@ class DownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1456,7 +1465,7 @@ class DownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None

-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1558,7 +1567,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None

-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1657,12 +1666,12 @@ class AttnSkipDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        skip_sample: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        skip_sample: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1748,12 +1757,12 @@ class SkipDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        skip_sample: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        skip_sample: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1841,8 +1850,8 @@ class ResnetDownsampleBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1977,16 +1986,16 @@ class SimpleCrossAttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

@@ -2088,8 +2097,8 @@ class KDownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2192,16 +2201,16 @@ class KCrossAttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

@@ -2349,13 +2358,13 @@ class AttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2472,18 +2481,18 @@ class CrossAttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -2607,13 +2616,13 @@ class UpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2732,7 +2741,7 @@ class UpDecoderBlock2D(nn.Module):

         self.resolution_idx = resolution_idx

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states, temb=temb)

@@ -2830,7 +2839,7 @@ class AttnUpDecoderBlock2D(nn.Module):

         self.resolution_idx = resolution_idx

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb=temb)
             hidden_states = attn(hidden_states, temb=temb)
@@ -2938,13 +2947,13 @@ class AttnSkipUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         skip_sample=None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3050,13 +3059,13 @@ class SkipUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         skip_sample=None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3157,13 +3166,13 @@ class ResnetUpsampleBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3301,18 +3310,18 @@ class SimpleCrossAttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -3419,13 +3428,13 @@ class KUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3549,15 +3558,15 @@ class KCrossAttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         res_hidden_states_tuple = res_hidden_states_tuple[-1]
         if res_hidden_states_tuple is not None:
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3675,26 +3684,26 @@ class KAttentionBlock(nn.Module):
             cross_attention_norm=cross_attention_norm,
         )

-    def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+    def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
         return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)

-    def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+    def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
         return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         # TODO: mark emb as non-optional (self.norm2 requires it).
         # requires assessing impact of change to positional param interface.
-        emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         # 1. Self-Attention
         if self.add_self_attention:
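
The other recurring change above is the handling of `scale`: the block-level `forward()` methods now warn and ignore a `scale` passed positionally or inside `cross_attention_kwargs`, and the deprecation message points users at supplying it via `cross_attention_kwargs` when calling the pipeline component instead. A hedged sketch of that call pattern follows; the model repo id and LoRA path are placeholders, not taken from this diff.

# Sketch of the call pattern implied by the deprecation messages above: the LoRA
# scale is routed through cross_attention_kwargs on the pipeline call rather than
# being handed to individual UNet blocks. Repo id and LoRA path are placeholders.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora_weights")  # placeholder LoRA checkpoint

image = pipe(
    "an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},  # LoRA scale applied at the pipeline level
).images[0]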