diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/unets/unet_motion_model.py
@@ -15,9 +15,10 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint
 
-from ...configuration_utils import ConfigMixin, register_to_config
+from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import UNet2DConditionLoadersMixin
 from ...utils import logging
 from ..attention_processor import (
@@ -27,6 +28,9 @@ from ..attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    AttnProcessor2_0,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -211,6 +215,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
         use_linear_projection: bool = False,
         num_attention_heads: Union[int, Tuple[int, ...]] = 8,
         motion_max_seq_length: int = 32,
@@ -218,6 +224,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         use_motion_mid_block: int = True,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        projection_class_embeddings_input_dim: Optional[int] = None,
         time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()
@@ -240,6 +249,21 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
         # input
         conv_in_kernel = 3
         conv_out_kernel = 3
@@ -260,6 +284,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if encoder_hid_dim_type is None:
             self.encoder_hid_proj = None
 
+        if addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
         # class embedding
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -267,6 +295,15 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(down_block_types)
 
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
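Note: with the validation and broadcasting added above, `UNetMotionModel.__init__` now accepts per-block values for `layers_per_block`, `cross_attention_dim`, and `transformer_layers_per_block`, while plain integers keep their old meaning. A minimal sketch of a configuration this enables; the concrete values below are illustrative and not taken from this diff:

from diffusers import UNetMotionModel

# Per-block settings: each list/tuple must match len(down_block_types) (4 in the default
# config), otherwise the new checks in __init__ raise a ValueError.
unet = UNetMotionModel(
    cross_attention_dim=(768, 768, 768, 768),
    layers_per_block=(2, 2, 2, 2),
    transformer_layers_per_block=(1, 1, 2, 2),
)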
@@ -276,7 +313,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
             down_block = get_down_block(
                 down_block_type,
-                num_layers=layers_per_block,
+                num_layers=layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
@@ -284,13 +321,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[i],
                 num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 dual_cross_attention=False,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[i],
             )
             self.down_blocks.append(down_block)
 
@@ -302,13 +340,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
             )
 
         else:
@@ -318,11 +357,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
             )
 
         # count how many layers upsample the images
@@ -331,6 +371,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
 
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +392,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
             up_block = get_up_block(
                 up_block_type,
-                num_layers=layers_per_block + 1,
+                num_layers=reversed_layers_per_block[i] + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -358,13 +401,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=reversed_cross_attention_dim[i],
                 num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=False,
                 resolution_idx=i,
                 use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -393,8 +437,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     ):
         has_motion_adapter = motion_adapter is not None
 
+        if has_motion_adapter:
+            motion_adapter.to(device=unet.device)
+
         # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
-        config = unet.config
+        config = dict(unet.config)
         config["_class_name"] = cls.__name__
         down_blocks = []
@@ -427,6 +474,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if not config.get("num_attention_heads"):
             config["num_attention_heads"] = config["attention_head_dim"]
 
+        config = FrozenDict(config)
         model = cls.from_config(config)
 
         if not load_weights:
@@ -446,6 +494,36 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         model.time_proj.load_state_dict(unet.time_proj.state_dict())
         model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
 
+        if any(
+            isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+            for proc in unet.attn_processors.values()
+        ):
+            attn_procs = {}
+            for name, processor in unet.attn_processors.items():
+                if name.endswith("attn1.processor"):
+                    attn_processor_class = (
+                        AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class()
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=processor.hidden_size,
+                        cross_attention_dim=processor.cross_attention_dim,
+                        scale=processor.scale,
+                        num_tokens=processor.num_tokens,
+                    )
+            for name, processor in model.attn_processors.items():
+                if name not in attn_procs:
+                    attn_procs[name] = processor.__class__()
+            model.set_attn_processor(attn_procs)
+            model.config.encoder_hid_dim_type = "ip_image_proj"
+            model.encoder_hid_proj = unet.encoder_hid_proj
+
         for i, down_block in enumerate(unet.down_blocks):
             model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
             if hasattr(model.down_blocks[i], "attentions"):
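The block above means IP-Adapter attention processors already loaded into a `UNet2DConditionModel` are carried over by `from_unet2d` (keeping their `hidden_size`, `cross_attention_dim`, `scale`, and `num_tokens`, and reusing the 2D UNet's `encoder_hid_proj`) instead of being dropped. A hedged sketch of how this could be exercised; the checkpoint names are illustrative and not part of this diff:

import torch
from diffusers import MotionAdapter, StableDiffusionPipeline, UNetMotionModel

# Illustrative checkpoints; any SD 1.5 base plus an AnimateDiff motion adapter works the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

# from_unet2d now copies the IPAdapterAttnProcessor(2_0) instances from the 2D UNet,
# so the motion UNet keeps image-prompt conditioning.
motion_unet = UNetMotionModel.from_unet2d(pipe.unet, motion_adapter=adapter)
assert motion_unet.config.encoder_hid_dim_type == "ip_image_proj"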
@@ -705,8 +783,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -742,7 +820,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -757,10 +835,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         The [`UNetMotionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                 Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
@@ -831,6 +909,28 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             t_emb = t_emb.to(dtype=self.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+
+        emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
         encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
 
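Together with the per-block parameters above, this `addition_embed_type="text_time"` path gives `UNetMotionModel` the SDXL-style pooled-text/time-ids conditioning that the new `pipelines/animatediff/pipeline_animatediff_sdxl.py` (see the file list) relies on. A hedged usage sketch; checkpoint names are illustrative and not taken from this diff:

import torch
from diffusers import AnimateDiffSDXLPipeline, MotionAdapter

# Illustrative checkpoints: an SDXL base model plus an SDXL motion adapter.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# The pipeline passes added_cond_kwargs={"text_embeds": ..., "time_ids": ...} into
# UNetMotionModel.forward, where the new "text_time" branch turns them into aug_emb.
frames = pipe(prompt="a panda surfing, studio lighting", num_frames=16, num_inference_steps=25).frames[0]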
 
diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -22,17 +22,17 @@ class UNetSpatioTemporalConditionOutput(BaseOutput):
     The output of [`UNetSpatioTemporalConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
 class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
-    shaped output.
+    A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
+    returns a sample shaped output.
 
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -57,7 +57,8 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             The dimension of the cross attention features.
         transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
             The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
-            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
+            [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
             [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
         num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
             The number of attention heads.
@@ -355,7 +356,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         added_time_ids: torch.Tensor,
@@ -365,21 +366,21 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         The [`UNetSpatioTemporalConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
-            added_time_ids: (`torch.FloatTensor`):
+            added_time_ids: (`torch.Tensor`):
                 The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
                 embeddings and added to the time embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
+                of a plain tuple.
         Returns:
             [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
+                returned, otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # 1. time
         timesteps = timestep
diffusers/models/unets/unet_stable_cascade.py
@@ -21,7 +21,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders.unet import FromOriginalUNetMixin
+from ...loaders import FromOriginalModelMixin
 from ...utils import BaseOutput
 from ..attention_processor import Attention
 from ..modeling_utils import ModelMixin
@@ -41,11 +41,11 @@ class SDCascadeLayerNorm(nn.LayerNorm):
 class SDCascadeTimestepBlock(nn.Module):
     def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
         self.conds = conds
         for cname in conds:
-            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
+            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))
 
     def forward(self, x, t):
         t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ class GlobalResponseNorm(nn.Module):
 class SDCascadeAttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear
 
         self.self_attn = self_attn
         self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
@@ -132,10 +131,10 @@ class UpDownBlock2d(nn.Module):
 
 @dataclass
 class StableCascadeUNetOutput(BaseOutput):
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
-class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
+class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -187,7 +186,8 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
            block_out_channels (Tuple[int], defaults to (2048, 2048)):
                Tuple of output channels for each block.
            num_attention_heads (Tuple[int], defaults to (32, 32)):
-                Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention.
+                Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have
+                attention.
            down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
                Number of layers in each down block.
            up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
@@ -198,10 +198,9 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
                Number of 1x1 Convolutional layers to repeat in each up block.
            block_types_per_layer (Tuple[Tuple[str]], optional,
                defaults to (
-                    ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
-                    ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
-                ):
-                Block types used in each layer of the up/down blocks.
+                    ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock",
+                    "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
+                ): Block types used in each layer of the up/down blocks.
            clip_text_in_channels (`int`, *optional*, defaults to `None`):
                Number of input channels for CLIP based text conditioning.
            clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
@@ -521,9 +520,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
                         if isinstance(block, SDCascadeResBlock):
                             skip = level_outputs[i] if k == 0 and i > 0 else None
                             if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+                                orig_type = x.dtype
                                 x = torch.nn.functional.interpolate(
                                     x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                                 )
+                                x = x.to(orig_type)
                             x = torch.utils.checkpoint.checkpoint(
                                 create_custom_forward(block), x, skip, use_reentrant=False
                             )
@@ -547,9 +548,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
                         if isinstance(block, SDCascadeResBlock):
                             skip = level_outputs[i] if k == 0 and i > 0 else None
                             if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+                                orig_type = x.dtype
                                 x = torch.nn.functional.interpolate(
                                     x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                                 )
+                                x = x.to(orig_type)
                             x = block(x, skip)
                         elif isinstance(block, SDCascadeAttnBlock):
                             x = block(x, clip)
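Switching `StableCascadeUNet` from `FromOriginalUNetMixin` to the new `FromOriginalModelMixin` (backed by the new `diffusers/loaders/single_file_model.py` in the file list) is what exposes single-file checkpoint loading for this model. A hedged sketch of that usage; the checkpoint location is illustrative:

import torch
from diffusers import StableCascadeUNet

# FromOriginalModelMixin provides from_single_file, so an original-format .safetensors
# checkpoint can be loaded directly into the diffusers model class.
unet = StableCascadeUNet.from_single_file(
    "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors",
    torch_dtype=torch.float16,
)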
diffusers/models/upsampling.py
@@ -110,7 +110,6 @@ class Upsample2D(nn.Module):
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ class Upsample2D(nn.Module):
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
@@ -139,9 +138,7 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -218,12 +215,12 @@ class FirUpsample2D(nn.Module):
 
     def _upsample_2d(
         self,
-        hidden_states: torch.FloatTensor,
-        weight: Optional[torch.FloatTensor] = None,
-        kernel: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """Fused `upsample_2d()` followed by `Conv2d()`.
 
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -231,19 +228,19 @@ class FirUpsample2D(nn.Module):
         arbitrary order.
 
         Args:
-            hidden_states (`torch.FloatTensor`):
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.FloatTensor`, *optional*):
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.FloatTensor`, *optional*):
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to nearest-neighbor upsampling.
             factor (`int`, *optional*): Integer upsampling factor (default: 2).
             gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
 
         Returns:
-            output (`torch.FloatTensor`):
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
                 datatype as `hidden_states`.
         """
@@ -311,7 +308,7 @@ class FirUpsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
             height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -402,11 +399,11 @@ def upfirdn2d_native(
 
 
 def upsample_2d(
-    hidden_states: torch.FloatTensor,
-    kernel: Optional[torch.FloatTensor] = None,
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
     r"""Upsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
     filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -414,9 +411,9 @@ def upsample_2d(
     a: multiple of the upsampling factor.
 
     Args:
-        hidden_states (`torch.FloatTensor`):
+        hidden_states (`torch.Tensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.FloatTensor`, *optional*):
+        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor (`int`, *optional*, default to `2`):
@@ -425,7 +422,7 @@ def upsample_2d(
            Scaling factor for signal magnitude (default: 1.0).
 
     Returns:
-        output (`torch.FloatTensor`):
+        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H * factor, W * factor]`
     """
     assert isinstance(factor, int) and factor >= 1
diffusers/models/vq_model.py
@@ -30,11 +30,11 @@ class VQEncoderOutput(BaseOutput):
     Output of VQModel encoding method.
 
     Args:
-        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The encoded output sample from the last layer of the model.
     """
 
-    latents: torch.FloatTensor
+    latents: torch.Tensor
 
 
 class VQModel(ModelMixin, ConfigMixin):
@@ -127,7 +127,7 @@ class VQModel(ModelMixin, ConfigMixin):
         )
 
     @apply_forward_hook
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
         h = self.encoder(x)
         h = self.quant_conv(h)
 
@@ -138,31 +138,33 @@ class VQModel(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def decode(
-        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
+    ) -> Union[DecoderOutput, torch.Tensor]:
         # also go through quantization layer
         if not force_not_quantize:
-            quant, _, _ = self.quantize(h)
+            quant, commit_loss, _ = self.quantize(h)
         elif self.config.lookup_from_codebook:
             quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
         else:
             quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
         quant2 = self.post_quant_conv(quant)
         dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
 
         if not return_dict:
-            return (dec,)
+            return dec, commit_loss
 
-        return DecoderOutput(sample=dec)
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
 
     def forward(
-        self, sample: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]:
+        self, sample: torch.Tensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
         r"""
         The [`VQModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
 
@@ -173,9 +175,8 @@ class VQModel(ModelMixin, ConfigMixin):
         """
 
         h = self.encode(sample).latents
-        dec = self.decode(h).sample
+        dec = self.decode(h)
 
         if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
+            return dec.sample, dec.commit_loss
+        return dec
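With the changes above, `VQModel.decode`/`forward` now surface the codebook commitment loss: `DecoderOutput` carries a `commit_loss` field and the `return_dict=False` tuple grows to two elements, a small breaking change for callers that unpacked a one-tuple. A minimal sketch of the new calling convention (randomly initialised model, purely illustrative):

import torch
from diffusers import VQModel

vq = VQModel()  # default config; a pretrained checkpoint behaves the same way
sample = torch.randn(1, 3, 32, 32)

out = vq(sample)                                   # DecoderOutput(sample=..., commit_loss=...)
dec, commit_loss = vq(sample, return_dict=False)   # tuple now has two elements instead of one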