diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff compares the published contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the versions exactly as they appear in their public registry.
Files changed (278)
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/pixart_transformer_2d.py (new file)
@@ -0,0 +1,336 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ..attention import BasicTransformerBlock
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
+    https://arxiv.org/abs/2403.04692).
+
+    Parameters:
+        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
+        in_channels (int, defaults to 4): The number of channels in the input.
+        out_channels (int, optional):
+            The number of channels in the output. Specify this parameter if the output channel number differs from the
+            input.
+        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
+        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
+        norm_num_groups (int, optional, defaults to 32):
+            Number of groups for group normalization within Transformer blocks.
+        cross_attention_dim (int, optional):
+            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
+        attention_bias (bool, optional, defaults to True):
+            Configure if the Transformer blocks' attention should contain a bias parameter.
+        sample_size (int, defaults to 128):
+            The width of the latent images. This parameter is fixed during training.
+        patch_size (int, defaults to 2):
+            Size of the patches the model processes, relevant for architectures working on non-sequential data.
+        activation_fn (str, optional, defaults to "gelu-approximate"):
+            Activation function to use in feed-forward networks within Transformer blocks.
+        num_embeds_ada_norm (int, optional, defaults to 1000):
+            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
+            inference.
+        upcast_attention (bool, optional, defaults to False):
+            If true, upcasts the attention mechanism dimensions for potentially improved performance.
+        norm_type (str, optional, defaults to "ada_norm_zero"):
+            Specifies the type of normalization used, can be 'ada_norm_zero'.
+        norm_elementwise_affine (bool, optional, defaults to False):
+            If true, enables element-wise affine parameters in the normalization layers.
+        norm_eps (float, optional, defaults to 1e-6):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+        interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
+        use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
+        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
+        caption_channels (int, optional, defaults to None):
+            Number of channels to use for projecting the caption embeddings.
+        use_linear_projection (bool, optional, defaults to False):
+            Deprecated argument. Will be removed in a future version.
+        num_vector_embeds (bool, optional, defaults to False):
+            Deprecated argument. Will be removed in a future version.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 72,
+        in_channels: int = 4,
+        out_channels: Optional[int] = 8,
+        num_layers: int = 28,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = 1152,
+        attention_bias: bool = True,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: Optional[int] = 1000,
+        upcast_attention: bool = False,
+        norm_type: str = "ada_norm_single",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+        use_additional_conditions: Optional[bool] = None,
+        caption_channels: Optional[int] = None,
+        attention_type: Optional[str] = "default",
+    ):
+        super().__init__()
+
+        # Validate inputs.
+        if norm_type != "ada_norm_single":
+            raise NotImplementedError(
+                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+            )
+        elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+            )
+
+        # Set some common variables used across the board.
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = in_channels if out_channels is None else out_channels
+        if use_additional_conditions is None:
+            if sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
+        self.gradient_checkpointing = False
+
+        # 2. Initialize the position embedding and transformer blocks.
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        # 3. Output blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+
+        self.adaln_single = AdaLayerNormSingle(
+            self.inner_dim, use_additional_conditions=self.use_additional_conditions
+        )
+        self.caption_projection = None
+        if self.config.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.config.caption_channels, hidden_size=self.inner_dim
+            )
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`PixArtTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep (`torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            attention_mask ( `torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if self.use_additional_conditions and added_cond_kwargs is None:
+            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size = hidden_states.shape[0]
+        height, width = (
+            hidden_states.shape[-2] // self.config.patch_size,
+            hidden_states.shape[-1] // self.config.patch_size,
+        )
+        hidden_states = self.pos_embed(hidden_states)
+
+        timestep, embedded_timestep = self.adaln_single(
+            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    cross_attention_kwargs,
+                    None,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=None,
+                )
+
+        # 3. Output
+        shift, scale = (
+            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+        # Modulation
+        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
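
For reference, the following is a minimal sketch of how the new `PixArtTransformer2DModel` could be exercised in isolation. It is not part of the diff: the tiny configuration values, the dummy tensors, and the `added_cond_kwargs` keys (`resolution` / `aspect_ratio`, mirroring what the PixArt pipelines pass) are assumptions chosen so the forward pass runs cheaply on CPU; real checkpoints use the defaults shown above (16 heads of dim 72, 28 blocks, sample_size 128).

import torch

from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel

# Hypothetical tiny config (assumption): inner_dim = 2 * 8 = 16 instead of the real 1152.
model = PixArtTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    out_channels=8,
    num_layers=2,
    cross_attention_dim=16,  # must match the last dimension of encoder_hidden_states
    sample_size=32,          # != 128, so use_additional_conditions resolves to False
    caption_channels=None,   # no caption projection layer is created
)

latents = torch.randn(1, 4, 32, 32)  # (batch, in_channels, height, width)
text_emb = torch.randn(1, 77, 16)    # (batch, seq_len, cross_attention_dim)
timestep = torch.tensor([999])

with torch.no_grad():
    out = model(
        hidden_states=latents,
        encoder_hidden_states=text_emb,
        timestep=timestep,
        # Keys mirror the PixArt pipelines; ignored here because additional conditions are off.
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    )

print(out.sample.shape)  # expected: torch.Size([1, 8, 32, 32]) after unpatchify

Because `sample_size` is not 128 in this sketch, `use_additional_conditions` defaults to False (see the `__init__` above), so the resolution and aspect-ratio embeddings are skipped.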
diffusers/models/transformers/prior_transformer.py
@@ -26,11 +26,11 @@ class PriorTransformerOutput(BaseOutput):
     The output of [`PriorTransformer`].
 
     Args:
-        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+        predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
     """
 
-    predicted_image_embedding: torch.FloatTensor
+    predicted_image_embedding: torch.Tensor
 
 
 class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
@@ -246,8 +246,8 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
-        proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        proj_embedding: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         return_dict: bool = True,
     ):
@@ -255,13 +255,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         The [`PriorTransformer`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                 Current denoising step.
-            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 Projected embedding vector the denoising process is conditioned on.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                 Hidden states of the text embeddings the denoising process is conditioned on.
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
diffusers/models/transformers/t5_film_transformer.py
@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
         self.post_dropout = nn.Dropout(p=dropout_rate)
         self.spec_out = nn.Linear(d_model, input_dims, bias=False)
 
-    def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
+    def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
         mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
         return mask.unsqueeze(-3)
 
@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias=None,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor]:
         hidden_states = self.layer[0](
             hidden_states,
             conditioning_emb=conditioning_emb,
@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         # pre_self_attention_layer_norm
         normed_hidden_states = self.layer_norm(hidden_states)
 
@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        key_value_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.attention(
             normed_hidden_states,
@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module):
         self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
         self.dropout = nn.Dropout(dropout_rate)
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
         forwarded_states = self.layer_norm(hidden_states)
         if conditioning_emb is not None:
             forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module):
         self.dropout = nn.Dropout(dropout_rate)
         self.act = NewGELUActivation()
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = hidden_gelu * hidden_linear
@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module):
         super().__init__()
         self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
 
-    def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
         emb = self.scale_bias(conditioning_emb)
         scale, shift = torch.chunk(emb, 2, -1)
         x = x * (1 + scale) + shift
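
The `PriorTransformer` and `T5FilmDecoder` hunks above appear to be part of a broader annotation cleanup in this release that replaces `torch.FloatTensor` with `torch.Tensor` in type hints and docstrings. A quick illustration (not from the diff) of why the old annotation was too narrow:

import torch

x32 = torch.randn(2)                        # default float32 tensor on CPU
x16 = torch.randn(2, dtype=torch.float16)   # half-precision tensor

print(isinstance(x32, torch.FloatTensor))   # True
print(isinstance(x16, torch.FloatTensor))   # False: FloatTensor only denotes float32 CPU tensors
print(isinstance(x16, torch.Tensor))        # True: Tensor covers every dtype and device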