diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -35,12 +35,12 @@ class Transformer2DModelOutput(BaseOutput):
     The output of [`Transformer2DModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
 
 
 class Transformer2DModel(ModelMixin, ConfigMixin):
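Note: the `torch.FloatTensor` → `torch.Tensor` substitutions that recur throughout this release are annotation and docstring changes only; runtime behavior is unchanged. The broader hint is also the more accurate one once a model runs in half precision, since an fp16 tensor is not a `torch.FloatTensor`. A minimal sketch (plain PyTorch, values made up) of why the old annotation was too narrow:

import torch

x_fp32 = torch.randn(2, 4, 8, 8)          # default float32 on CPU
x_fp16 = x_fp32.to(torch.float16)         # what fp16 pipelines actually pass around

print(isinstance(x_fp32, torch.FloatTensor))   # True
print(isinstance(x_fp16, torch.FloatTensor))   # False -- it is a HalfTensor
print(isinstance(x_fp16, torch.Tensor))        # True -- the new annotation covers it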
@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(
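Note: `_no_split_modules` is the attribute that Accelerate-style device-map inference consults so the listed module classes are kept whole on a single device. A rough sketch of how such a list is typically consumed; how diffusers wires this into its own loading paths is not shown in this hunk, so treat the call below as an assumption rather than the library's internal code path (the config and memory budget are placeholders):

from accelerate import infer_auto_device_map
from diffusers import Transformer2DModel

# Small, throwaway config purely for illustration.
model = Transformer2DModel(num_attention_heads=2, attention_head_dim=32, in_channels=32, num_layers=2)

device_map = infer_auto_device_map(
    model,
    max_memory={"cpu": "2GB"},                        # placeholder budget
    no_split_module_classes=model._no_split_modules,  # ["BasicTransformerBlock"]
)
print(device_map)  # no BasicTransformerBlock is split across devices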
@@ -100,8 +101,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         attention_type: str = "default",
         caption_channels: int = None,
         interpolation_scale: float = None,
+        use_additional_conditions: Optional[bool] = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
@@ -112,13 +116,22 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
+        # Set some common variables used across the board.
         self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
-
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
 
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
@@ -129,7 +142,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
             deprecation_message = (
                 f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
                 " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                 " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                 " would be very nice if you could open a Pull request for the `transformer/config.json` file"
@@ -153,104 +166,165 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        # 2. Define input layers
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
         if self.is_input_continuous:
-            self.in_channels = in_channels
-
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
-            else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
 
-            self.height = sample_size
-            self.width = sample_size
-            self.num_vector_embeds = num_vector_embeds
-            self.num_latent_pixels = self.height * self.width
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
 
-            self.latent_image_embedding = ImagePositionalEmbeddings(
-                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
-            )
-        elif self.is_input_patches:
-            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
 
-            self.height = sample_size
-            self.width = sample_size
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
 
-            self.patch_size = patch_size
-            interpolation_scale = (
-                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
-            )
-            self.pos_embed = PatchEmbed(
-                height=sample_size,
-                width=sample_size,
-                patch_size=patch_size,
-                in_channels=in_channels,
-                embed_dim=inner_dim,
-                interpolation_scale=interpolation_scale,
-            )
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
 
-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for d in range(num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
-        # 4. Define output layers
-        self.out_channels = in_channels if out_channels is None else out_channels
-        if self.is_input_continuous:
-            # TODO: should use out_channels for continuous projections
-            if use_linear_projection:
-                self.proj_out = linear_cls(inner_dim, in_channels)
-            else:
-                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
-            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
-        elif self.is_input_patches and norm_type != "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
-            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-        elif self.is_input_patches and norm_type == "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
-        # 5. PixArt-Alpha blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
-        self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
-            self.use_additional_conditions = self.config.sample_size == 128
+        if self.config.norm_type == "ada_norm_single":
             # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
             # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
 
         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
-        self.gradient_checkpointing = False
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
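Note: together with the refactor above, `use_additional_conditions` is now a constructor argument. When it is left at `None`, the previous heuristic is kept: it resolves to `True` only for `norm_type="ada_norm_single"` with `sample_size == 128` (the PixArt-Alpha 1024px layout), otherwise `False`. A small sketch under made-up, PixArt-style hyperparameters:

from diffusers import Transformer2DModel

common = dict(
    in_channels=4,
    patch_size=2,
    norm_type="ada_norm_single",
    num_embeds_ada_norm=1000,
    num_attention_heads=2,
    attention_head_dim=32,
)

# Left as None: inferred from norm_type + sample_size.
inferred = Transformer2DModel(sample_size=128, **common)
print(inferred.use_additional_conditions)   # True

# Stated explicitly (here it matches what would be inferred for sample_size=64).
explicit = Transformer2DModel(sample_size=64, use_additional_conditions=False, **common)
print(explicit.use_additional_conditions)   # False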
@@ -272,9 +346,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         The [`Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                 Input `hidden_states`.
-            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.LongTensor`, *optional*):
@@ -308,7 +382,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -334,41 +408,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )
 
         # 2. Blocks
-        if self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
 
@@ -406,51 +457,116 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
+
+        if not return_dict:
+            return (output,)
 
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+        return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
 
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+        return hidden_states, inner_dim
+
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
+
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
 
-        if not return_dict:
-            return (output,)
+        output = hidden_states + residual
+        return output
 
-        return Transformer2DModelOutput(sample=output)
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )
+        return output
@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput):
     The output of [`TransformerTemporalModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
 
 
 class TransformerTemporalModel(ModelMixin, ConfigMixin):
@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.LongTensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         class_labels: torch.LongTensor = None,
@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
         The [`TransformerTemporal`] forward method.
 
         Args:
-            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                 Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module):
     ):
         """
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input hidden_states.
             num_frames (`int`):
                 The number of frames to be processed per batch. This is used to reshape the hidden states.
@@ -294,8 +294,8 @@ class TransformerSpatioTemporalModel(nn.Module):
                 A tensor indicating whether the input contains only images. 1 indicates that the input contains only
                 images, 0 indicates that the input contains video frames.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
+                plain tuple.
 
         Returns:
             [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module):
            time_context_first_timestep = time_context[None, :].reshape(
                batch_size, num_frames, -1, time_context.shape[-1]
            )[:, 0]
-            time_context = time_context_first_timestep[None, :].broadcast_to(
-                height * width, batch_size, 1, time_context.shape[-1]
+            time_context = time_context_first_timestep[:, None].broadcast_to(
+                batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
            )
-            time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
+            time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
 
        residual = hidden_states
 
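Note: the TransformerSpatioTemporalModel hunk above changes how the first-frame time context is broadcast over spatial positions before temporal attention. A shape-only sketch in plain PyTorch (sizes made up): the old broadcast only worked with a single context token per frame and laid positions out spatial-major, while the new one is batch-major and keeps the full sequence.

import torch

batch_size, num_frames, seq_len, dim = 2, 4, 77, 32
height = width = 8

encoder_hidden_states = torch.randn(batch_size * num_frames, seq_len, dim)
time_context = encoder_hidden_states
time_context_first_timestep = time_context[None, :].reshape(
    batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0]                                    # (batch_size, seq_len, dim)

# 0.27.2 formulation: broadcast_to(height * width, batch_size, 1, dim) is only
# valid when seq_len == 1; with seq_len == 77 it raises, so it is not run here.

# 0.28.0 formulation: batch-major and sequence-length agnostic.
new = time_context_first_timestep[:, None].broadcast_to(
    batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
)
new = new.reshape(batch_size * height * width, -1, time_context.shape[-1])
print(new.shape)                           # torch.Size([128, 77, 32])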
@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput):
     The output of [`UNet1DModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`):
            The hidden states output from the last layer of the model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
 
 
 class UNet1DModel(ModelMixin, ConfigMixin):
@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
 
     def forward(
        self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
        The [`UNet1DModel`] forward method.
 
        Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.