diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
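The largest user-facing addition in this list is the new diffusers/quantizers package (bitsandbytes, GGUF, and torchao backends), which in current diffusers releases pairs with a `quantization_config` argument on model `from_pretrained`. A minimal sketch of the bitsandbytes path; the repo id and dtype choices are illustrative assumptions, and the bitsandbytes package must be installed:

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Sketch only: 4-bit NF4 loading through the new quantizers package.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # illustrative repo id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)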
@@ -11,20 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput

@@ -32,7 +37,75 @@ from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name


-class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
+class SD3Transformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+):
     """
     The Transformer model introduced in Stable Diffusion 3.

@@ -69,6 +142,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         pooled_projection_dim: int = 2048,
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         default_out_channels = in_channels
@@ -97,6 +174,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=i == num_layers - 1,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(self.config.num_layers)
             ]
@@ -262,6 +341,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -271,11 +351,11 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                 Input `hidden_states`.
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+                Embeddings projected from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -284,6 +364,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.

         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -310,8 +392,17 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         temb = self.time_text_embed(timestep, pooled_projections)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
+
+            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
+
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -328,18 +419,21 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-                interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
-                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
+                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
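The hunks above add `SD3SingleTransformerBlock`, the SD3.5-style `dual_attention_layers`/`qk_norm` config options, IP-Adapter plumbing through `joint_attention_kwargs`, and a new `skip_layers` argument on `forward`. A minimal sketch of the new arguments against a deliberately tiny, randomly initialised config (all sizes are assumptions chosen for illustration, not a real checkpoint):

import torch
from diffusers import SD3Transformer2DModel

# Tiny illustrative config; dual_attention_layers/qk_norm are the new SD3.5-style options.
model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=4,
    out_channels=4,
    num_layers=4,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    pos_embed_max_size=64,
    dual_attention_layers=(0, 1),
    qk_norm="rms_norm",
)

output = model(
    hidden_states=torch.randn(1, 4, 32, 32),
    encoder_hidden_states=torch.randn(1, 77, 32),
    pooled_projections=torch.randn(1, 64),
    timestep=torch.tensor([1]),
    skip_layers=[2],  # new: bypass the listed transformer blocks (as used by skip-layer guidance)
).sample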
@@ -340,7 +340,7 @@ class TransformerSpatioTemporalModel(nn.Module):

         # 2. Blocks
         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     block,
                     hidden_states,
@@ -217,7 +217,7 @@ class MidResTemporalBlock1D(nn.Module):
         if self.upsample:
             hidden_states = self.upsample(hidden_states)
         if self.downsample:
-            self.downsample = self.downsample(hidden_states)
+            hidden_states = self.downsample(hidden_states)

         return hidden_states

@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             conditioning with `class_embed_type` equal to `None`.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         out_channels: int = 3,
         center_input_sample: bool = False,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         super().__init__()

         self.sample_size = sample_size
-        time_embed_dim = block_out_channels[0] * 4
+        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

         # Check inputs
         if len(down_block_types) != len(up_block_types):
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.Tensor,
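The UNet2DModel hunks above declare gradient-checkpointing support and add a configurable `time_embedding_dim`; the block-level hunks that follow change the checkpointing gate from `self.training` to `torch.is_grad_enabled()`, so checkpointing also applies when gradients are enabled outside of `.train()` mode. A small sketch with an assumed tiny config:

import torch
from diffusers import UNet2DModel

# Tiny illustrative config (assumed values); `time_embedding_dim` is the new constructor argument.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    time_embedding_dim=128,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
unet.enable_gradient_checkpointing()  # now possible: _supports_gradient_checkpointing = True

sample = torch.randn(1, 3, 32, 32)
out = unet(sample, timestep=1).sample
out.mean().backward()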
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

@@ -859,7 +882,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -1257,7 +1303,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1371,7 +1417,7 @@ class DownBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1859,7 +1905,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2011,7 +2057,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
             mask = attention_mask

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2106,7 +2152,7 @@ class KDownBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2215,7 +2261,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2520,7 +2587,7 @@ class CrossAttnUpBlock2D(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2653,7 +2720,7 @@ class UpBlock2D(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3183,7 +3250,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3341,7 +3408,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -3444,7 +3511,7 @@ class KUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3572,7 +3639,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 4,
         out_channels: int = 4,
         center_input_sample: bool = False,
@@ -463,7 +463,6 @@ class UNet2DConditionModel(
                 dropout=dropout,
             )
             self.up_blocks.append(up_block)
-            prev_output_channel = output_channel

         # out
         if norm_num_groups is not None:
@@ -599,7 +598,7 @@ class UNet2DConditionModel(
             )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
-                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+                f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
             )
         else:
             self.encoder_hid_proj = None
@@ -679,7 +678,9 @@ class UNet2DConditionModel(
             # Kandinsky 2.2 ControlNet
             self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+            raise ValueError(
+                f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
+            )

     def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
         if attention_type in ["gated", "gated-text-image"]:
@@ -990,7 +991,7 @@ class UNet2DConditionModel(
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
-            # Kandinsky 2.2 - style
+            # Kandinsky 2.2 ControlNet - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
@@ -1009,7 +1010,7 @@ class UNet2DConditionModel(
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1018,14 +1019,14 @@ class UNet2DConditionModel(
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
@@ -1140,7 +1141,6 @@ class UNet2DConditionModel(
        # 1. time
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)
-        aug_emb = None

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
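In the UNet2DConditionModel hunks above, `sample_size` is now typed to accept a `(height, width)` tuple as well as an int, two leftover assignments (`prev_output_channel`, `aug_emb = None`) are removed, and several error messages now list every accepted value and refer to the correct `added_cond_kwargs` argument. A small sketch of the widened `sample_size` type with an assumed tiny config:

import torch
from diffusers import UNet2DConditionModel

# Tiny illustrative config (assumed values); sample_size may now be a (height, width) tuple.
unet = UNet2DConditionModel(
    sample_size=(48, 32),
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 48, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=0, encoder_hidden_states=encoder_hidden_states).sample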
@@ -1078,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
        )

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing: # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
@@ -1168,7 +1168,7 @@ class DownBlockSpatioTemporal(nn.Module):
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        output_states = ()
        for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
@@ -1281,7 +1281,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):

        blocks = list(zip(self.resnets, self.attentions))
        for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing: # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
+        upsample_size: Optional[int] = None,
    ) -> torch.Tensor:
        for resnet in self.resnets:
            # pop res hidden states
@@ -1383,7 +1384,7 @@ class UpBlockSpatioTemporal(nn.Module):

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
+        upsample_size: Optional[int] = None,
    ) -> torch.Tensor:
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
@@ -1493,7 +1495,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing: # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states