diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
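Beyond the per-file churn above, the list shows two entirely new quantization backends (`diffusers/quantizers/gguf/` and `diffusers/quantizers/torchao/`) alongside new model families (Allegro, HunyuanVideo, LTX, Mochi, Sana). As rough orientation only, a minimal sketch of how the new GGUF path is driven via `from_single_file`; the checkpoint URL is illustrative and the exact keyword surface should be checked against the 0.32.1 docs:

    import torch

    from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

    # Illustrative community GGUF checkpoint; substitute a real single-file URL.
    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )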
@@ -156,9 +156,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
 
         # define temporal positional embedding
         temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
-            inner_dim, torch.arange(0, video_length).unsqueeze(1)
+            inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
         )  # 1152 hidden size
-        self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+        self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
 
         self.gradient_checkpointing = False
 
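The hunk above tracks the 0.32 move away from NumPy in the sincos positional-embedding helpers: with `output_type="pt"` the table is built directly as a torch tensor, so the buffer registration no longer needs a `torch.from_numpy` round-trip. A minimal sketch of the same call in isolation (assuming the helper's import path from `diffusers.models.embeddings`):

    import torch

    from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

    video_length, inner_dim = 16, 1152
    pos = torch.arange(0, video_length).unsqueeze(1)

    # With output_type="pt" the helper returns a torch.Tensor of shape
    # (video_length, inner_dim); the old default returned a NumPy array.
    temp_pos_embed = get_1d_sincos_pos_embed_from_grid(inner_dim, pos, output_type="pt")
    print(temp_pos_embed.shape)  # torch.Size([16, 1152])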
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
@@ -386,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -0,0 +1,488 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import (
+    Attention,
+    AttentionProcessor,
+    AttnProcessor2_0,
+    SanaLinearAttnProcessor2_0,
+)
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class GLUMBConv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expand_ratio: float = 4,
+        norm_type: Optional[str] = None,
+        residual_connection: bool = True,
+    ) -> None:
+        super().__init__()
+
+        hidden_channels = int(expand_ratio * in_channels)
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
+
+        self.nonlinearity = nn.SiLU()
+        self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+        self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+        self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+        self.norm = None
+        if norm_type == "rms_norm":
+            self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.residual_connection:
+            residual = hidden_states
+
+        hidden_states = self.conv_inverted(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.conv_depth(hidden_states)
+        hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+        hidden_states = hidden_states * self.nonlinearity(gate)
+
+        hidden_states = self.conv_point(hidden_states)
+
+        if self.norm_type == "rms_norm":
+            # move channel to the last dimension so we apply RMSnorm across channel dimension
+            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if self.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class SanaTransformerBlock(nn.Module):
+    r"""
+    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
+    """
+
+    def __init__(
+        self,
+        dim: int = 2240,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        dropout: float = 0.0,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        attention_bias: bool = True,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        attention_out_bias: bool = True,
+        mlp_ratio: float = 2.5,
+    ) -> None:
+        super().__init__()
+
+        # 1. Self Attention
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=None,
+            processor=SanaLinearAttnProcessor2_0(),
+        )
+
+        # 2. Cross Attention
+        if cross_attention_dim is not None:
+            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_cross_attention_heads,
+                dim_head=cross_attention_head_dim,
+                dropout=dropout,
+                bias=True,
+                out_bias=attention_out_bias,
+                processor=AttnProcessor2_0(),
+            )
+
+        # 3. Feed-forward
+        self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        height: int = None,
+        width: int = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+
+        # 1. Modulation
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+        ).chunk(6, dim=1)
+
+        # 2. Self Attention
+        norm_hidden_states = self.norm1(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
+
+        attn_output = self.attn1(norm_hidden_states)
+        hidden_states = hidden_states + gate_msa * attn_output
+
+        # 3. Cross Attention
+        if self.attn2 is not None:
+            attn_output = self.attn2(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
+        hidden_states = hidden_states + gate_mlp * ff_output
+
+        return hidden_states
+
+
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    r"""
+    A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
+
+    Args:
+        in_channels (`int`, defaults to `32`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `32`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `70`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `32`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of Transformer blocks to use.
+        num_cross_attention_heads (`int`, *optional*, defaults to `20`):
+            The number of heads to use for cross-attention.
+        cross_attention_head_dim (`int`, *optional*, defaults to `112`):
+            The number of channels in each head for cross-attention.
+        cross_attention_dim (`int`, *optional*, defaults to `2240`):
+            The number of channels in the cross-attention output.
+        caption_channels (`int`, defaults to `2304`):
+            The number of channels in the caption embeddings.
+        mlp_ratio (`float`, defaults to `2.5`):
+            The expansion ratio to use in the GLUMBConv layer.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability.
+        attention_bias (`bool`, defaults to `False`):
+            Whether to use bias in the attention layer.
+        sample_size (`int`, defaults to `32`):
+            The base size of the input latent.
+        patch_size (`int`, defaults to `1`):
+            The size of the patches to use in the patch embedding layer.
+        norm_elementwise_affine (`bool`, defaults to `False`):
+            Whether to use elementwise affinity in the normalization layer.
+        norm_eps (`float`, defaults to `1e-6`):
+            The epsilon value for the normalization layer.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: Optional[int] = 32,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        num_layers: int = 20,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        caption_channels: int = 2304,
+        mlp_ratio: float = 2.5,
+        dropout: float = 0.0,
+        attention_bias: bool = False,
+        sample_size: int = 32,
+        patch_size: int = 1,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
+        self.patch_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        # 2. Additional condition embeddings
+        self.time_embed = AdaLayerNormSingle(inner_dim)
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                SanaTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    mlp_ratio=mlp_ratio,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output blocks
+        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+
+        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size, num_channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+
+        hidden_states = self.patch_embed(hidden_states)
+
+        timestep, embedded_timestep = self.time_embed(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+        # 2. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for block in self.transformer_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+
+        # 3. Normalization
+        shift, scale = (
+            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+
+        # 4. Modulation
+        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.proj_out(hidden_states)
+
+        # 5. Unpatchify
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
+        )
+        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
+        output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
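For orientation, a minimal smoke test of the new `diffusers/models/transformers/sana_transformer.py` added above, using a deliberately tiny config instead of the 2240-dim defaults; the argument names follow the `__init__` and `forward` signatures in the hunk, but treat this as a sketch rather than reference usage:

    import torch

    from diffusers import SanaTransformer2DModel

    # Tiny config for a shape check: inner_dim = num_attention_heads * attention_head_dim = 16.
    model = SanaTransformer2DModel(
        in_channels=4,
        out_channels=4,
        num_attention_heads=2,
        attention_head_dim=8,
        num_layers=2,
        num_cross_attention_heads=2,
        cross_attention_head_dim=8,
        cross_attention_dim=16,  # captions are projected to inner_dim, so keep these equal
        caption_channels=24,
        sample_size=32,
    )

    latents = torch.randn(1, 4, 32, 32)
    captions = torch.randn(1, 10, 24)
    timestep = torch.tensor([500], dtype=torch.long)

    sample = model(latents, captions, timestep).sample
    print(sample.shape)  # torch.Size([1, 4, 32, 32])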
@@ -414,7 +414,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
             attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
 
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -415,7 +415,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
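The Latte, PixArt, StableAudio, and `Transformer2DModel` hunks all swap the same gradient-checkpointing guard: `self.training` becomes `torch.is_grad_enabled()`, so checkpointing now keys off whether autograd is actually recording rather than off train/eval mode. A standalone sketch of the behavioral difference (the `Block` module here is illustrative, not from diffusers):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint


    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.gradient_checkpointing = True

        def forward(self, x):
            # New-style guard: under torch.no_grad() this falls through to a
            # plain forward even if the module is still in train() mode.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                return checkpoint(self.proj, x, use_reentrant=False)
            return self.proj(x)


    block = Block().train()
    with torch.no_grad():
        # A self.training-based guard would still route through checkpoint()
        # here; the new guard skips the pointless recomputation machinery.
        out = block(torch.randn(2, 8))
    print(out.shape)  # torch.Size([2, 8])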