diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -60,14 +60,14 @@ class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
     Output class for Stable Diffusion pipelines.
 
     Args:
-        latents (`torch.FloatTensor`)
+        latents (`torch.Tensor`)
            inverted latents tensor
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """
 
-    latents: torch.FloatTensor
+    latents: torch.Tensor
     images: Union[List[PIL.Image.Image], np.ndarray]
@@ -377,8 +377,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -410,8 +410,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -431,10 +431,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
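
Nearly every hunk in this diff is the same mechanical migration: `torch.FloatTensor` in type hints and docstrings becomes `torch.Tensor`. `torch.FloatTensor` names the legacy float32 tensor type, so the old annotations were misleading whenever a pipeline ran in float16. A minimal, diffusers-independent sketch of why the broader annotation is the accurate one:

    import torch

    # float16 embeddings, as a pipeline loaded with torch_dtype=torch.float16 would produce
    prompt_embeds = torch.randn(1, 77, 768, dtype=torch.float16)

    # The value is a torch.Tensor, but its concrete type name is torch.HalfTensor,
    # not torch.FloatTensor -- so `Optional[torch.Tensor]` matches reality.
    print(isinstance(prompt_embeds, torch.Tensor))  # True
    print(prompt_embeds.type())                     # torch.HalfTensor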
@@ -661,7 +661,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -702,7 +707,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
         return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
 
     @torch.no_grad()
-    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
+    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor:
         num_prompts = len(prompt)
         embeds = []
         for i in range(0, num_prompts, batch_size):
@@ -822,13 +827,13 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         cross_attention_guidance_amount: float = 0.1,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
@@ -871,14 +876,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -892,7 +897,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1107,12 +1112,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
         num_inference_steps: int = 50,
         guidance_scale: float = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         cross_attention_guidance_amount: float = 0.1,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         lambda_auto_corr: float = 20.0,
@@ -1127,7 +1132,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+            image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
                 image latents as `image`, if passing latents directly, it will not be encoded again.
             num_inference_steps (`int`, *optional*, defaults to 50):
@@ -1142,11 +1147,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             cross_attention_guidance_amount (`float`, defaults to 0.1):
@@ -1159,7 +1164,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -363,6 +363,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]
 
     @register_to_config
     def __init__(
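
`_no_split_modules` tells accelerate-style device-map planning which submodule classes must stay whole on a single device when the model is sharded (splitting a block with internal residual connections would break its forward pass). A hedged sketch of how such a list is typically consumed; the toy classes below are stand-ins, and `infer_auto_device_map` is the real `accelerate` API:

    import torch.nn as nn
    from accelerate import infer_auto_device_map

    class ToyBlock(nn.Module):  # stands in for e.g. BasicTransformerBlock
        def __init__(self):
            super().__init__()
            self.proj_in, self.proj_out = nn.Linear(64, 256), nn.Linear(256, 64)

    class ToyUNet(nn.Module):
        _no_split_modules = ["ToyBlock"]  # mirrors the class attribute added above
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(ToyBlock() for _ in range(4))

    model = ToyUNet()
    # Planning respects the no-split list: each ToyBlock lands on exactly one device
    # (here "cpu" or the "disk" overflow), never half-and-half.
    device_map = infer_auto_device_map(
        model,
        max_memory={"cpu": "300KB"},
        no_split_module_classes=ToyUNet._no_split_modules,
    )
    print(device_map)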
@@ -531,7 +532,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         elif encoder_hid_dim_type == "text_image_proj":
             # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
             self.encoder_hid_proj = TextImageProjection(
                 text_embed_dim=encoder_hid_dim,
                 image_embed_dim=cross_attention_dim,
@@ -591,7 +592,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
@@ -816,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
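
The `isinstance` cleanup is behavior-preserving: `isinstance` accepts a tuple of types, so one call replaces the `or` of two. A one-line illustration:

    cross_attention_dim = [768, 1024]
    assert isinstance(cross_attention_dim, (list, tuple))  # same result as the two-call form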
@@ -1000,8 +1001,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
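
The `fuse_qkv_projections` docstring is only re-wrapped here, but the method itself is the user-facing switch for collapsing attention projections into fewer, larger matmuls. A hedged usage sketch, assuming a Stable Diffusion pipeline whose UNet exposes this method (as the `UNet2DConditionModel` this class is copied from does); weights must be downloaded, so this is illustrative rather than a test:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Fuse projections (self-attn: q, k, v; cross-attn: k, v) before inference;
    # unfuse_qkv_projections() restores the original layout.
    pipe.unet.fuse_qkv_projections()
    image = pipe("a photo of an astronaut riding a horse").images[0]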
 
@@ -1047,7 +1048,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -1065,10 +1066,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         The [`UNetFlatConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
@@ -1112,8 +1113,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -1257,7 +1258,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
@@ -1589,8 +1590,8 @@ class DownBlockFlat(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         for resnet in self.resnets:
@@ -1718,14 +1719,14 @@ class CrossAttnDownBlockFlat(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         blocks = list(zip(self.resnets, self.attentions))
@@ -1836,13 +1837,13 @@ class UpBlockFlat(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1993,18 +1994,18 @@ class CrossAttnUpBlockFlat(nn.Module):
 
     def forward(
        self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         is_freeu_enabled = (
             getattr(self, "s1", None)
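
Several block-level `forward` methods in this file now warn (with the `depcrecated` typo fixed) instead of honoring a `scale` entry in their `cross_attention_kwargs`: the LoRA scale is applied at the pipeline boundary, which pops `scale` out before any block runs. A hedged sketch of the supported call pattern; the LoRA repository id is hypothetical:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("some-org/some-lora")  # hypothetical repo id

    # Supported: the pipeline extracts `scale` from cross_attention_kwargs and applies
    # it as the LoRA scale itself, so no block-level warning fires. Passing `scale`
    # directly to a block's forward is what the warning above now rejects.
    image = pipe("a watercolor fox", cross_attention_kwargs={"scale": 0.7}).images[0]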
@@ -2103,8 +2104,8 @@ class UNetMidBlockFlat(nn.Module):
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
 
     Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
+        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
+        height, width)`.
 
     """
@@ -2222,7 +2223,7 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
@@ -2238,6 +2239,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -2245,6 +2247,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -2256,6 +2259,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2264,14 +2271,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers
 
+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlockFlat(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -2286,11 +2296,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
+                        norm_num_groups=resnet_groups_out,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         attention_type=attention_type,
@@ -2300,8 +2310,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 attentions.append(
                     DualTransformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
@@ -2309,11 +2319,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 )
             resnets.append(
                 ResnetBlockFlat(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
-                    groups=resnet_groups,
+                    groups=resnet_groups_out,
                     dropout=dropout,
                     time_embedding_norm=resnet_time_scale_shift,
                     non_linearity=resnet_act_fn,
@@ -2329,16 +2339,16 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
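
The new `out_channels`/`resnet_groups_out` arguments let a cross-attention mid block change channel width: the first resnet maps `in_channels` to `out_channels`, and the attention layers and remaining resnets then operate at `out_channels`, with their GroupNorms sized by `resnet_groups_out`. A pure-PyTorch sketch of that channel flow, with plain convolutions standing in for `ResnetBlockFlat` and `Transformer2DModel`:

    import torch
    import torch.nn as nn

    in_channels, out_channels, groups_out = 320, 640, 32

    # First resnet widens the stream; every later stage sees `out_channels`.
    first_resnet = nn.Conv2d(in_channels, out_channels, 3, padding=1)
    attn_norm = nn.GroupNorm(groups_out, out_channels)     # norm_num_groups=resnet_groups_out
    second_resnet = nn.Conv2d(out_channels, out_channels, 3, padding=1)

    h = torch.randn(1, in_channels, 8, 8)
    h = first_resnet(h)               # (1, 640, 8, 8): in_channels -> out_channels
    h = second_resnet(attn_norm(h))   # stays at out_channels
    print(h.shape)                    # torch.Size([1, 640, 8, 8])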
@@ -2470,16 +2480,16 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -81,7 +81,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
     @torch.no_grad()
     def image_variation(
         self,
-        image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.Tensor, PIL.Image.Image],
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -90,10 +90,10 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -123,7 +123,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             generator (`torch.Generator`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
@@ -134,7 +134,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
@@ -202,10 +202,10 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -235,7 +235,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             generator (`torch.Generator`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
@@ -246,7 +246,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
@@ -311,10 +311,10 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -344,7 +344,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
@@ -355,7 +355,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -348,7 +348,12 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -390,10 +395,10 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
         **kwargs,
     ):
@@ -424,7 +429,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
@@ -434,7 +439,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.