diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
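
Before digging into the hunks, a minimal sanity check (not part of the diff) can confirm that the installed wheel is 0.32.0 and that one of the newly added modules resolves; the import path below is taken from entry 59 of the list above, and nothing else is assumed about the environment.

# Sanity check sketch: verify version and that the new Sana transformer module exists.
import diffusers
from diffusers.models.transformers.sana_transformer import SanaTransformer2DModel

print(diffusers.__version__)              # expected: 0.32.0
print(SanaTransformer2DModel.__module__)  # diffusers.models.transformers.sana_transformer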
diffusers/models/transformers/latte_transformer_3d.py
@@ -156,9 +156,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
 
         # define temporal positional embedding
         temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
-            inner_dim, torch.arange(0, video_length).unsqueeze(1)
+            inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
         )  # 1152 hidden size
-        self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+        self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
 
         self.gradient_checkpointing = False
 
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
diffusers/models/transformers/pixart_transformer_2d.py
@@ -386,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diffusers/models/transformers/sana_transformer.py
@@ -0,0 +1,488 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import (
+    Attention,
+    AttentionProcessor,
+    AttnProcessor2_0,
+    SanaLinearAttnProcessor2_0,
+)
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class GLUMBConv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expand_ratio: float = 4,
+        norm_type: Optional[str] = None,
+        residual_connection: bool = True,
+    ) -> None:
+        super().__init__()
+
+        hidden_channels = int(expand_ratio * in_channels)
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
+
+        self.nonlinearity = nn.SiLU()
+        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+        self.norm = None
+        if norm_type == "rms_norm":
+            self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.residual_connection:
+            residual = hidden_states
+
+        hidden_states = self.conv_inverted(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.conv_depth(hidden_states)
+        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+        hidden_states = hidden_states * self.nonlinearity(gate)
+
+        hidden_states = self.conv_point(hidden_states)
+
+        if self.norm_type == "rms_norm":
+            # move channel to the last dimension so we apply RMSnorm across channel dimension
+            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if self.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class SanaTransformerBlock(nn.Module):
+    r"""
+    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
+    """
+
+    def __init__(
+        self,
+        dim: int = 2240,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        dropout: float = 0.0,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        attention_bias: bool = True,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        attention_out_bias: bool = True,
+        mlp_ratio: float = 2.5,
+    ) -> None:
+        super().__init__()
+
+        # 1. Self Attention
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=None,
+            processor=SanaLinearAttnProcessor2_0(),
+        )
+
+        # 2. Cross Attention
+        if cross_attention_dim is not None:
+            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_cross_attention_heads,
+                dim_head=cross_attention_head_dim,
+                dropout=dropout,
+                bias=True,
+                out_bias=attention_out_bias,
+                processor=AttnProcessor2_0(),
+            )
+
+        # 3. Feed-forward
+        self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        height: int = None,
+        width: int = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+
+        # 1. Modulation
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+        ).chunk(6, dim=1)
+
+        # 2. Self Attention
+        norm_hidden_states = self.norm1(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
+
+        attn_output = self.attn1(norm_hidden_states)
+        hidden_states = hidden_states + gate_msa * attn_output
+
+        # 3. Cross Attention
+        if self.attn2 is not None:
+            attn_output = self.attn2(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
+        hidden_states = hidden_states + gate_mlp * ff_output
+
+        return hidden_states
+
+
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    r"""
+    A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
+
+    Args:
+        in_channels (`int`, defaults to `32`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `32`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `70`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `32`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of Transformer blocks to use.
+        num_cross_attention_heads (`int`, *optional*, defaults to `20`):
+            The number of heads to use for cross-attention.
+        cross_attention_head_dim (`int`, *optional*, defaults to `112`):
+            The number of channels in each head for cross-attention.
+        cross_attention_dim (`int`, *optional*, defaults to `2240`):
+            The number of channels in the cross-attention output.
+        caption_channels (`int`, defaults to `2304`):
+            The number of channels in the caption embeddings.
+        mlp_ratio (`float`, defaults to `2.5`):
+            The expansion ratio to use in the GLUMBConv layer.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability.
+        attention_bias (`bool`, defaults to `False`):
+            Whether to use bias in the attention layer.
+        sample_size (`int`, defaults to `32`):
+            The base size of the input latent.
+        patch_size (`int`, defaults to `1`):
+            The size of the patches to use in the patch embedding layer.
+        norm_elementwise_affine (`bool`, defaults to `False`):
+            Whether to use elementwise affinity in the normalization layer.
+        norm_eps (`float`, defaults to `1e-6`):
+            The epsilon value for the normalization layer.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: Optional[int] = 32,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        num_layers: int = 20,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        caption_channels: int = 2304,
+        mlp_ratio: float = 2.5,
+        dropout: float = 0.0,
+        attention_bias: bool = False,
+        sample_size: int = 32,
+        patch_size: int = 1,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
+        self.patch_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        # 2. Additional condition embeddings
+        self.time_embed = AdaLayerNormSingle(inner_dim)
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                SanaTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    mlp_ratio=mlp_ratio,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output blocks
+        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+
+        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size, num_channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+
+        hidden_states = self.patch_embed(hidden_states)
+
+        timestep, embedded_timestep = self.time_embed(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+        # 2. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for block in self.transformer_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+
+        # 3. Normalization
+        shift, scale = (
+            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+
+        # 4. Modulation
+        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.proj_out(hidden_states)
+
+        # 5. Unpatchify
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
+        )
+        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
+        output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
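
The new module above is self-contained enough to smoke-test from the diff alone. The sketch below relies only on the constructor arguments and forward() signature shown in this hunk; the tiny dimensions (2 heads of 8, a single layer, an 8x8 latent) are invented for illustration and do not correspond to any released Sana checkpoint.

import torch
from diffusers.models.transformers.sana_transformer import SanaTransformer2DModel

# Toy configuration for illustration only; the defaults in the diff
# (70 heads x 32 dims, 20 layers, caption_channels=2304) describe the full model.
model = SanaTransformer2DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    num_cross_attention_heads=2,
    cross_attention_head_dim=8,
    cross_attention_dim=16,   # matches inner_dim (= heads * head_dim), as 2240 = 70 * 32 does in the defaults
    caption_channels=8,
    sample_size=8,
    patch_size=1,
)

latent = torch.randn(1, 4, 8, 8)                  # [batch, in_channels, height, width]
prompt_embeds = torch.randn(1, 16, 8)             # [batch, seq_len, caption_channels]
timestep = torch.tensor([500], dtype=torch.long)  # one timestep per batch element

with torch.no_grad():
    sample = model(latent, encoder_hidden_states=prompt_embeds, timestep=timestep).sample

print(sample.shape)  # torch.Size([1, 4, 8, 8]) -- with patch_size=1 the unpatchify step keeps the latent's spatial shape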
diffusers/models/transformers/stable_audio_transformer.py
@@ -414,7 +414,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
             attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
 
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diffusers/models/transformers/transformer_2d.py
@@ -415,7 +415,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
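
The one-line guard change repeated across these hunks (and baked into the new SanaTransformer2DModel above) swaps self.training for torch.is_grad_enabled() when deciding whether to route blocks through torch.utils.checkpoint. A standalone sketch of what that changes in practice, using a hypothetical toy module rather than diffusers code:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ToyBlockStack(nn.Module):
    """Hypothetical module reproducing only the checkpointing guard from the diff."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
        self.gradient_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            # Old guard: `self.training and self.gradient_checkpointing`
            # New guard: also recomputes activations when the module sits in eval()
            # mode but gradients still flow (e.g. a frozen base model during adapter
            # training), and skips the checkpoint machinery under torch.no_grad().
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = ToyBlockStack().eval()              # eval() but autograd still enabled
x = torch.randn(2, 16, requires_grad=True)
model(x).sum().backward()                   # takes the checkpointing branch

with torch.no_grad():                       # pure inference: plain forward, no checkpoint overhead
    model(x)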