diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -11,39 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...configuration_utils import LegacyConfigMixin, register_to_config
+from ...utils import deprecate, is_torch_version, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
-from ..modeling_utils import ModelMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import LegacyModelMixin
 from ..normalization import AdaLayerNormSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-@dataclass
-class Transformer2DModelOutput(BaseOutput):
-    """
-    The output of [`Transformer2DModel`].
-
-    Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
-            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
-            distributions for the unnoised latent pixels.
-    """
-
-    sample: torch.FloatTensor
+class Transformer2DModelOutput(Transformer2DModelOutput):
+    deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
+    deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
 
 
-class Transformer2DModel(ModelMixin, ConfigMixin):
+class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
     """
     A 2D Transformer model for image-like data.
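The hunk above (from diffusers/models/transformers/transformer_2d.py) replaces the locally defined output dataclass with a deprecation shim and re-exports the canonical class from diffusers.models.modeling_outputs. A minimal migration sketch, assuming diffusers 0.28.x is installed:

# Recommended import location from 0.28 onwards (per the deprecation message above):
from diffusers.models.modeling_outputs import Transformer2DModelOutput

# The legacy location still resolves through the shim, but emits a deprecation warning:
# from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput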
 
@@ -72,6 +63,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(
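The new `_no_split_modules` class attribute marks `BasicTransformerBlock` as a unit that must stay on a single device when a device map is computed during big-model loading. A minimal, hedged sketch of reading the hint (how a loader consumes it is outside this hunk):

from diffusers.models import Transformer2DModel

# Class-level hint added in 0.28.x: module classes that must not be split across devices.
print(Transformer2DModel._no_split_modules)  # ['BasicTransformerBlock']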
@@ -100,8 +92,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         attention_type: str = "default",
         caption_channels: int = None,
         interpolation_scale: float = None,
+        use_additional_conditions: Optional[bool] = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
@@ -112,31 +107,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
-        self.use_linear_projection = use_linear_projection
-        self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
-
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         self.is_input_vectorized = num_vector_embeds is not None
         self.is_input_patches = in_channels is not None and patch_size is not None
 
-        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
-            deprecation_message = (
-                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
-                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
-                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
-                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
-            )
-            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
-            norm_type = "ada_norm"
-
         if self.is_input_continuous and self.is_input_vectorized:
             raise ValueError(
                 f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
@@ -153,104 +129,194 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        # 2. Define input layers
-        if self.is_input_continuous:
-            self.in_channels = in_channels
+        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+            deprecation_message = (
+                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+            )
+            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+            norm_type = "ada_norm"
 
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
+        # Set some common variables used across the board.
+        self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
             else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
+        if self.is_input_continuous:
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
 
-            self.height = sample_size
-            self.width = sample_size
-            self.num_vector_embeds = num_vector_embeds
-            self.num_latent_pixels = self.height * self.width
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
 
-            self.latent_image_embedding = ImagePositionalEmbeddings(
-                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
-            )
-        elif self.is_input_patches:
-            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
 
-            self.height = sample_size
-            self.width = sample_size
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
 
-            self.patch_size = patch_size
-            interpolation_scale = (
-                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
-            )
-            self.pos_embed = PatchEmbed(
-                height=sample_size,
-                width=sample_size,
-                patch_size=patch_size,
-                in_channels=in_channels,
-                embed_dim=inner_dim,
-                interpolation_scale=interpolation_scale,
-            )
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
 
-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for d in range(num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
-        # 4. Define output layers
-        self.out_channels = in_channels if out_channels is None else out_channels
-        if self.is_input_continuous:
-            # TODO: should use out_channels for continuous projections
-            if use_linear_projection:
-                self.proj_out = linear_cls(inner_dim, in_channels)
-            else:
-                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
-            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
-        elif self.is_input_patches and norm_type != "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
-            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-        elif self.is_input_patches and norm_type == "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
-        # 5. PixArt-Alpha blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
-        self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
-            self.use_additional_conditions = self.config.sample_size == 128
+        if self.config.norm_type == "ada_norm_single":
            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
            # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
 
         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
-        self.gradient_checkpointing = False
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
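With this refactor the constructor only routes to one of the three `_init_*` helpers depending on which inputs are configured, and `use_additional_conditions` is resolved from the config when left as `None` (True only for `ada_norm_single` with `sample_size == 128`). A minimal sketch with toy sizes, not a real checkpoint configuration:

from diffusers.models import Transformer2DModel

# DiT-style patched configuration: `in_channels` and `patch_size` are both set,
# so __init__ routes through _init_patched_inputs().
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    norm_type="ada_norm_zero",
    num_embeds_ada_norm=10,
    sample_size=8,
    patch_size=2,
)
print(model.is_input_patches)           # True
print(model.use_additional_conditions)  # False here; True only for ada_norm_single with sample_size == 128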
@@ -272,9 +338,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
        The [`Transformer2DModel`] forward method.

        Args:
-            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
-            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
@@ -308,7 +374,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -334,41 +400,18 @@
 
         # 1. Input
         if self.is_input_continuous:
-            batch, _, height, width = hidden_states.shape
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.pos_embed(hidden_states)
-
-            if self.adaln_single is not None:
-                if self.use_additional_conditions and added_cond_kwargs is None:
-                    raise ValueError(
-                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                    )
-                batch_size = hidden_states.shape[0]
-                timestep, embedded_timestep = self.adaln_single(
-                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-                )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )
 
         # 2. Blocks
-        if self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
 
@@ -406,51 +449,116 @@
 
         # 3. Output
         if self.is_input_continuous:
-            if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states)
-            else:
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
-            output = hidden_states + residual
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
 
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
 
-        if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)
+
+        return hidden_states, inner_dim
+
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
+
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-                hidden_states = self.proj_out_2(hidden_states)
-            elif self.config.norm_type == "ada_norm_single":
-                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states)
-                # Modulation
-                hidden_states = hidden_states * (1 + scale) + shift
-                hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.squeeze(1)
-
-            # unpatchify
-            if self.adaln_single is None:
-                height = width = int(hidden_states.shape[1] ** 0.5)
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
             )
-            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-            output = hidden_states.reshape(
-                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            )
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
 
-        if not return_dict:
-            return (output,)
+        output = hidden_states + residual
+        return output
 
-        return Transformer2DModelOutput(sample=output)
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output
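The forward pass itself is unchanged in behaviour: the refactor only moves the per-input-type logic into the `_operate_on_*` and `_get_output_for_*` helpers above, and the return contract (`Transformer2DModelOutput`, or a plain tuple with `return_dict=False`) stays the same. A self-contained sketch on the same toy patched configuration (hypothetical sizes, not a real checkpoint):

import torch
from diffusers.models import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    norm_type="ada_norm_zero",
    num_embeds_ada_norm=10,
    sample_size=8,
    patch_size=2,
).eval()

sample = torch.randn(1, 4, 8, 8)  # (batch, in_channels, height, width)
timestep = torch.tensor([10])     # diffusion timestep
class_labels = torch.tensor([0])  # class label required by the ada_norm_zero (class-conditional) path

with torch.no_grad():
    out = model(sample, timestep=timestep, class_labels=class_labels)
print(out.sample.shape)  # torch.Size([1, 4, 8, 8])

# With return_dict=False the same call returns a plain tuple instead:
with torch.no_grad():
    (sample_out,) = model(sample, timestep=timestep, class_labels=class_labels, return_dict=False)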