diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -329,13 +329,6 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
  safety_checker=safety_checker,
  feature_extractor=feature_extractor,
  )
- processor = (
- CrossFrameAttnProcessor2_0(batch_size=2)
- if hasattr(F, "scaled_dot_product_attention")
- else CrossFrameAttnProcessor(batch_size=2)
- )
- self.unet.set_attn_processor(processor)
-
  if safety_checker is None and requires_safety_checker:
  logger.warning(
  f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -399,7 +392,7 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
  `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
  callback (`Callable`, *optional*):
  A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
  callback_steps (`int`, *optional*, defaults to 1):
  The frequency at which the `callback` function is called. If not specified, the callback is called at
  every step.
@@ -502,7 +495,12 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn

  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
  if isinstance(generator, list) and len(generator) != batch_size:
  raise ValueError(
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch
@@ -531,12 +529,12 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
  num_videos_per_prompt: Optional[int] = 1,
  eta: float = 0.0,
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
  motion_field_strength_x: float = 12,
  motion_field_strength_y: float = 12,
  output_type: Optional[str] = "tensor",
  return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
  callback_steps: Optional[int] = 1,
  t0: int = 44,
  t1: int = 47,
@@ -571,19 +569,19 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  tensor is generated by sampling using the supplied random `generator`.
- output_type (`str`, *optional*, defaults to `"numpy"`):
- The output format of the generated video. Choose between `"latent"` and `"numpy"`.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `"latent"` and `"np"`.
  return_dict (`bool`, *optional*, defaults to `True`):
  Whether or not to return a
  [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`] instead of
  a plain tuple.
  callback (`Callable`, *optional*):
  A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
  callback_steps (`int`, *optional*, defaults to 1):
  The frequency at which the `callback` function is called. If not specified, the callback is called at
  every step.
@@ -616,6 +614,15 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn

  assert num_videos_per_prompt == 1

+ # set the processor
+ original_attn_proc = self.unet.attn_processors
+ processor = (
+ CrossFrameAttnProcessor2_0(batch_size=2)
+ if hasattr(F, "scaled_dot_product_attention")
+ else CrossFrameAttnProcessor(batch_size=2)
+ )
+ self.unet.set_attn_processor(processor)
+
  if isinstance(prompt, str):
  prompt = [prompt]
  if isinstance(negative_prompt, str):
@@ -739,6 +746,8 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn

  # Offload all models
  self.maybe_free_model_hooks()
+ # make sure to set the original attention processors back
+ self.unet.set_attn_processor(original_attn_proc)

  if not return_dict:
  return (image, has_nsfw_concept)
@@ -786,8 +795,8 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
  num_images_per_prompt,
  do_classifier_free_guidance,
  negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  lora_scale: Optional[float] = None,
  clip_skip: Optional[int] = None,
  ):
@@ -807,10 +816,10 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
  less than `1`).
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  argument.
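
Note: the hunks above move the cross-frame attention processor out of `__init__` and into `__call__`, and restore the UNet's original attention processors once generation finishes. A minimal sketch of the observable difference (checkpoint, prompt, and step count below are illustrative, not taken from the diff):

```python
import torch
from diffusers import TextToVideoZeroPipeline

# Illustrative checkpoint; any SD 1.x checkpoint compatible with this pipeline works.
pipe = TextToVideoZeroPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

procs_before = {name: type(proc) for name, proc in pipe.unet.attn_processors.items()}
result = pipe(prompt="A panda is playing guitar on Times Square", num_inference_steps=4)
procs_after = {name: type(proc) for name, proc in pipe.unet.attn_processors.items()}

# In 0.28.0 the cross-frame processor is only active inside the call,
# so the UNet's processors look the same before and after generation.
assert procs_before == procs_after
```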
@@ -411,14 +411,6 @@ class TextToVideoZeroSDXLPipeline(
411
411
  else:
412
412
  self.watermark = None
413
413
 
414
- processor = (
415
- CrossFrameAttnProcessor2_0(batch_size=2)
416
- if hasattr(F, "scaled_dot_product_attention")
417
- else CrossFrameAttnProcessor(batch_size=2)
418
- )
419
-
420
- self.unet.set_attn_processor(processor)
421
-
422
414
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
423
415
  def prepare_extra_step_kwargs(self, generator, eta):
424
416
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -479,7 +471,12 @@ class TextToVideoZeroSDXLPipeline(
479
471
 
480
472
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
481
473
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
482
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
474
+ shape = (
475
+ batch_size,
476
+ num_channels_latents,
477
+ int(height) // self.vae_scale_factor,
478
+ int(width) // self.vae_scale_factor,
479
+ )
483
480
  if isinstance(generator, list) and len(generator) != batch_size:
484
481
  raise ValueError(
485
482
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -584,10 +581,10 @@ class TextToVideoZeroSDXLPipeline(
584
581
  do_classifier_free_guidance: bool = True,
585
582
  negative_prompt: Optional[str] = None,
586
583
  negative_prompt_2: Optional[str] = None,
587
- prompt_embeds: Optional[torch.FloatTensor] = None,
588
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
589
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
590
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
584
+ prompt_embeds: Optional[torch.Tensor] = None,
585
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
586
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
587
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
591
588
  lora_scale: Optional[float] = None,
592
589
  clip_skip: Optional[int] = None,
593
590
  ):
@@ -613,17 +610,17 @@ class TextToVideoZeroSDXLPipeline(
613
610
  negative_prompt_2 (`str` or `List[str]`, *optional*):
614
611
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
615
612
  `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
616
- prompt_embeds (`torch.FloatTensor`, *optional*):
613
+ prompt_embeds (`torch.Tensor`, *optional*):
617
614
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
618
615
  provided, text embeddings will be generated from `prompt` input argument.
619
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
616
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
620
617
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
621
618
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
622
619
  argument.
623
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
620
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
624
621
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
625
622
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
626
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
623
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
627
624
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
628
625
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
629
626
  input argument.
@@ -864,7 +861,7 @@ class TextToVideoZeroSDXLPipeline(
864
861
  `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
865
862
  callback (`Callable`, *optional*):
866
863
  A function that calls every `callback_steps` steps during inference. The function is called with the
867
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
864
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
868
865
  callback_steps (`int`, *optional*, defaults to 1):
869
866
  The frequency at which the `callback` function is called. If not specified, the callback is called at
870
867
  every step.
@@ -936,16 +933,16 @@ class TextToVideoZeroSDXLPipeline(
936
933
  eta: float = 0.0,
937
934
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
938
935
  frame_ids: Optional[List[int]] = None,
939
- prompt_embeds: Optional[torch.FloatTensor] = None,
940
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
941
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
942
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
943
- latents: Optional[torch.FloatTensor] = None,
936
+ prompt_embeds: Optional[torch.Tensor] = None,
937
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
938
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
939
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
940
+ latents: Optional[torch.Tensor] = None,
944
941
  motion_field_strength_x: float = 12,
945
942
  motion_field_strength_y: float = 12,
946
943
  output_type: Optional[str] = "tensor",
947
944
  return_dict: bool = True,
948
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
945
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
949
946
  callback_steps: int = 1,
950
947
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
951
948
  guidance_rescale: float = 0.0,
@@ -1005,21 +1002,21 @@ class TextToVideoZeroSDXLPipeline(
1005
1002
  frame_ids (`List[int]`, *optional*):
1006
1003
  Indexes of the frames that are being generated. This is used when generating longer videos
1007
1004
  chunk-by-chunk.
1008
- prompt_embeds (`torch.FloatTensor`, *optional*):
1005
+ prompt_embeds (`torch.Tensor`, *optional*):
1009
1006
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1010
1007
  provided, text embeddings will be generated from `prompt` input argument.
1011
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1008
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1012
1009
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1013
1010
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1014
1011
  argument.
1015
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1012
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1016
1013
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1017
1014
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
1018
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1015
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1019
1016
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1020
1017
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1021
1018
  input argument.
1022
- latents (`torch.FloatTensor`, *optional*):
1019
+ latents (`torch.Tensor`, *optional*):
1023
1020
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1024
1021
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1025
1022
  tensor will ge generated by sampling using the supplied random `generator`.
@@ -1037,7 +1034,7 @@ class TextToVideoZeroSDXLPipeline(
1037
1034
  of a plain tuple.
1038
1035
  callback (`Callable`, *optional*):
1039
1036
  A function that will be called every `callback_steps` steps during inference. The function will be
1040
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1037
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
1041
1038
  callback_steps (`int`, *optional*, defaults to 1):
1042
1039
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1043
1040
  called at every step.
@@ -1084,6 +1081,15 @@ class TextToVideoZeroSDXLPipeline(
1084
1081
 
1085
1082
  assert num_videos_per_prompt == 1
1086
1083
 
1084
+ # set the processor
1085
+ original_attn_proc = self.unet.attn_processors
1086
+ processor = (
1087
+ CrossFrameAttnProcessor2_0(batch_size=2)
1088
+ if hasattr(F, "scaled_dot_product_attention")
1089
+ else CrossFrameAttnProcessor(batch_size=2)
1090
+ )
1091
+ self.unet.set_attn_processor(processor)
1092
+
1087
1093
  if isinstance(prompt, str):
1088
1094
  prompt = [prompt]
1089
1095
  if isinstance(negative_prompt, str):
@@ -1305,9 +1311,9 @@ class TextToVideoZeroSDXLPipeline(
1305
1311
 
1306
1312
  image = self.image_processor.postprocess(image, output_type=output_type)
1307
1313
 
1308
- # Offload last model to CPU manually for max memory savings
1309
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1310
- self.final_offload_hook.offload()
1314
+ self.maybe_free_model_hooks()
1315
+ # make sure to set the original attention processors back
1316
+ self.unet.set_attn_processor(original_attn_proc)
1311
1317
 
1312
1318
  if not return_dict:
1313
1319
  return (image,)
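
Both text-to-video-zero pipelines also pick up the `int()` cast in `prepare_latents`. A small sketch of the failure mode the cast guards against, assuming a float height/width reaches the method (the values below are illustrative):

```python
import torch

height, width = 512.0, 512.0  # e.g. produced by upstream floating-point math
vae_scale_factor = 8

# Old shape: float dimensions, which torch.randn rejects with a TypeError.
old_shape = (1, 4, height // vae_scale_factor, width // vae_scale_factor)
# New shape: the int() cast keeps every dimension an integer.
new_shape = (1, 4, int(height) // vae_scale_factor, int(width) // vae_scale_factor)

latents = torch.randn(new_shape)  # works: shape is (1, 4, 64, 64)
```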
@@ -217,9 +217,9 @@ class UnCLIPPipeline(DiffusionPipeline):
  decoder_num_inference_steps: int = 25,
  super_res_num_inference_steps: int = 7,
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- prior_latents: Optional[torch.FloatTensor] = None,
- decoder_latents: Optional[torch.FloatTensor] = None,
- super_res_latents: Optional[torch.FloatTensor] = None,
+ prior_latents: Optional[torch.Tensor] = None,
+ decoder_latents: Optional[torch.Tensor] = None,
+ super_res_latents: Optional[torch.Tensor] = None,
  text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
  text_attention_mask: Optional[torch.Tensor] = None,
  prior_guidance_scale: float = 4.0,
@@ -248,11 +248,11 @@ class UnCLIPPipeline(DiffusionPipeline):
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  generation deterministic.
- prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
+ prior_latents (`torch.Tensor` of shape (batch size, embeddings dimension), *optional*):
  Pre-generated noisy latents to be used as inputs for the prior.
- decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
  Pre-generated noisy latents to be used as inputs for the decoder.
- super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
  Pre-generated noisy latents to be used as inputs for the decoder.
  prior_guidance_scale (`float`, *optional*, defaults to 4.0):
  A higher guidance scale value encourages the model to generate images closely linked to the text
@@ -199,13 +199,13 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
  @torch.no_grad()
  def __call__(
  self,
- image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
+ image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor]] = None,
  num_images_per_prompt: int = 1,
  decoder_num_inference_steps: int = 25,
  super_res_num_inference_steps: int = 7,
  generator: Optional[torch.Generator] = None,
- decoder_latents: Optional[torch.FloatTensor] = None,
- super_res_latents: Optional[torch.FloatTensor] = None,
+ decoder_latents: Optional[torch.Tensor] = None,
+ super_res_latents: Optional[torch.Tensor] = None,
  image_embeddings: Optional[torch.Tensor] = None,
  decoder_guidance_scale: float = 8.0,
  output_type: Optional[str] = "pil",
@@ -215,7 +215,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
  The call function to the pipeline for generation.

  Args:
- image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
  `Image` or tensor representing an image batch to be used as the starting point. If you provide a
  tensor, it needs to be compatible with the [`CLIPImageProcessor`]
  [configuration](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
@@ -231,9 +231,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
  generator (`torch.Generator`, *optional*):
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  generation deterministic.
- decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
  Pre-generated noisy latents to be used as inputs for the decoder.
- super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
  Pre-generated noisy latents to be used as inputs for the decoder.
  decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
  A higher guidance scale value encourages the model to generate images closely linked to the text
@@ -220,7 +220,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
  must be supplied.
- input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+ input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  An embedded representation to directly pass to the transformer as a prefix for beam search. One of
  `input_ids` and `input_embeds` must be supplied.
  device:
@@ -739,8 +739,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
  """
  Args:
  hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
- When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
- hidden_states
+ When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states
  encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
  Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
  self-attention.
@@ -752,7 +751,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
  cross_attention_kwargs (*optional*):
  Keyword arguments to supply to the cross attention layers, if used.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+ Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
  hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
  Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
  ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
@@ -1037,9 +1037,9 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):

  def forward(
  self,
- latent_image_embeds: torch.FloatTensor,
- image_embeds: torch.FloatTensor,
- prompt_embeds: torch.FloatTensor,
+ latent_image_embeds: torch.Tensor,
+ image_embeds: torch.Tensor,
+ prompt_embeds: torch.Tensor,
  timestep_img: Union[torch.Tensor, float, int],
  timestep_text: Union[torch.Tensor, float, int],
  data_type: Optional[Union[torch.Tensor, float, int]] = 1,
@@ -1048,11 +1048,11 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
  ):
  """
  Args:
- latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`):
+ latent_image_embeds (`torch.Tensor` of shape `(batch size, latent channels, height, width)`):
  Latent image representation from the VAE encoder.
- image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`):
+ image_embeds (`torch.Tensor` of shape `(batch size, 1, clip_img_dim)`):
  CLIP-embedded image representation (unsqueezed in the first dimension).
- prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`):
+ prompt_embeds (`torch.Tensor` of shape `(batch size, seq_len, text_dim)`):
  CLIP-embedded text representation.
  timestep_img (`torch.long` or `float` or `int`):
  Current denoising step for the image.
@@ -304,7 +304,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
  if isinstance(image, PIL.Image.Image):
  batch_size = 1
  else:
- # Image must be available and type either PIL.Image.Image or torch.FloatTensor.
+ # Image must be available and type either PIL.Image.Image or torch.Tensor.
  # Not currently supporting something like image_embeds.
  batch_size = image.shape[0]
  multiplier = num_prompts_per_image
@@ -353,8 +353,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
  num_images_per_prompt,
  do_classifier_free_guidance,
  negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  lora_scale: Optional[float] = None,
  **kwargs,
  ):
@@ -386,8 +386,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
  num_images_per_prompt,
  do_classifier_free_guidance,
  negative_prompt=None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  lora_scale: Optional[float] = None,
  clip_skip: Optional[int] = None,
  ):
@@ -407,10 +407,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
  less than `1`).
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  argument.
@@ -1080,7 +1080,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
  def __call__(
  self,
  prompt: Optional[Union[str, List[str]]] = None,
- image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+ image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
  height: Optional[int] = None,
  width: Optional[int] = None,
  data_type: Optional[int] = 1,
@@ -1091,15 +1091,15 @@ class UniDiffuserPipeline(DiffusionPipeline):
  num_prompts_per_image: Optional[int] = 1,
  eta: float = 0.0,
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_latents: Optional[torch.FloatTensor] = None,
- vae_latents: Optional[torch.FloatTensor] = None,
- clip_latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_latents: Optional[torch.Tensor] = None,
+ vae_latents: Optional[torch.Tensor] = None,
+ clip_latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
  output_type: Optional[str] = "pil",
  return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
  callback_steps: int = 1,
  ):
  r"""
@@ -1109,7 +1109,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
  prompt (`str` or `List[str]`, *optional*):
  The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
  Required for text-conditioned image generation (`text2img`) mode.
- image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+ image (`torch.Tensor` or `PIL.Image.Image`, *optional*):
  `Image` or tensor representing an image batch. Required for image-conditioned text generation
  (`img2text`) mode.
  height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1144,29 +1144,29 @@ class UniDiffuserPipeline(DiffusionPipeline):
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for joint
  image-text generation. Can be used to tweak the same generation with different prompts. If not
  provided, a latents tensor is generated by sampling using the supplied random `generator`. This assumes
  a full set of VAE, CLIP, and text latents, if supplied, overrides the value of `prompt_latents`,
  `vae_latents`, and `clip_latents`.
- prompt_latents (`torch.FloatTensor`, *optional*):
+ prompt_latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for text
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  tensor is generated by sampling using the supplied random `generator`.
- vae_latents (`torch.FloatTensor`, *optional*):
+ vae_latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  tensor is generated by sampling using the supplied random `generator`.
- clip_latents (`torch.FloatTensor`, *optional*):
+ clip_latents (`torch.Tensor`, *optional*):
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  tensor is generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
  provided, text embeddings are generated from the `prompt` input argument. Used in text-conditioned
  image generation (`text2img`) mode.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
  not provided, `negative_prompt_embeds` are be generated from the `negative_prompt` input argument. Used
  in text-conditioned image generation (`text2img`) mode.
@@ -1176,7 +1176,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
  Whether or not to return a [`~pipelines.ImageTextPipelineOutput`] instead of a plain tuple.
  callback (`Callable`, *optional*):
  A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
  callback_steps (`int`, *optional*, defaults to 1):
  The frequency at which the `callback` function is called. If not specified, the callback is called at
  every step.
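
The UnCLIP, UniDiffuser, and related hunks above are part of the release-wide replacement of `torch.FloatTensor` with `torch.Tensor` in annotations and docstrings. `torch.FloatTensor` describes only float32 CPU tensors, while the arguments these pipelines actually receive are frequently half precision and/or on GPU, so the broader annotation matches real usage. A short, self-contained illustration:

```python
import torch

latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # typical half-precision latents

print(isinstance(latents, torch.Tensor))       # True: any dtype, any device
print(isinstance(latents, torch.FloatTensor))  # False: FloatTensor means float32 on CPU only
```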
@@ -130,7 +130,7 @@ class PaellaVQModel(ModelMixin, ConfigMixin):
  )

  @apply_forward_hook
- def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
  h = self.in_block(x)
  h = self.down_blocks(h)

@@ -141,8 +141,8 @@ class PaellaVQModel(ModelMixin, ConfigMixin):

  @apply_forward_hook
  def decode(
- self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True
- ) -> Union[DecoderOutput, torch.FloatTensor]:
+ self, h: torch.Tensor, force_not_quantize: bool = True, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
  if not force_not_quantize:
  quant, _, _ = self.vquantizer(h)
  else:
@@ -155,10 +155,10 @@ class PaellaVQModel(ModelMixin, ConfigMixin):

  return DecoderOutput(sample=dec)

- def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
  r"""
  Args:
- sample (`torch.FloatTensor`): Input sample.
+ sample (`torch.Tensor`): Input sample.
  return_dict (`bool`, *optional*, defaults to `True`):
  Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
  """
@@ -17,8 +17,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
  class TimestepBlock(nn.Module):
  def __init__(self, c, c_timestep):
  super().__init__()
- linear_cls = nn.Linear
- self.mapper = linear_cls(c_timestep, c * 2)
+
+ self.mapper = nn.Linear(c_timestep, c * 2)

  def forward(self, x, t):
  a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -29,13 +29,10 @@ class ResBlock(nn.Module):
  def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
  super().__init__()

- conv_cls = nn.Conv2d
- linear_cls = nn.Linear
-
- self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+ self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
  self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
  self.channelwise = nn.Sequential(
- linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+ nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
  )

  def forward(self, x, x_skip=None):
@@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
  def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
  super().__init__()

- linear_cls = nn.Linear
-
  self.self_attn = self_attn
  self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
  self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
- self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+ self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))

  def forward(self, x, kv):
  kv = self.kv_mapper(kv)
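
The Wuerstchen hunks remove the `linear_cls` / `conv_cls` aliases and construct `nn.Linear` / `nn.Conv2d` directly; the resulting modules are identical. A standalone sketch of the simplified `TimestepBlock` mapping (channel sizes are illustrative, and the final modulation line mirrors how the surrounding `forward` uses `a` and `b`):

```python
import torch
import torch.nn as nn

c, c_timestep = 16, 8
mapper = nn.Linear(c_timestep, c * 2)  # previously: linear_cls = nn.Linear; linear_cls(c_timestep, c * 2)

x = torch.randn(2, c, 32, 32)          # feature map
t = torch.randn(2, c_timestep)         # timestep embedding
a, b = mapper(t)[:, :, None, None].chunk(2, dim=1)  # split into scale and shift, broadcastable over H, W
out = x * (1 + a) + b                  # feature modulation applied by the block
print(out.shape)                       # torch.Size([2, 16, 32, 32])
```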