diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,499 @@
|
|
1
|
+
# Copyright 2024 The Genmo team and The HuggingFace Team.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
17
|
+
|
18
|
+
import torch
|
19
|
+
import torch.nn as nn
|
20
|
+
|
21
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
22
|
+
from ...loaders import PeftAdapterMixin
|
23
|
+
from ...loaders.single_file_model import FromOriginalModelMixin
|
24
|
+
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
25
|
+
from ...utils.torch_utils import maybe_allow_in_graph
|
26
|
+
from ..attention import FeedForward
|
27
|
+
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
28
|
+
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
29
|
+
from ..modeling_outputs import Transformer2DModelOutput
|
30
|
+
from ..modeling_utils import ModelMixin
|
31
|
+
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
32
|
+
|
33
|
+
|
34
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35
|
+
|
36
|
+
|
37
|
+
class MochiModulatedRMSNorm(nn.Module):
    r"""
    RMS normalization with an optional multiplicative modulation, computed in float32.

    The input is upcast to float32, normalized, optionally multiplied by `scale`, and cast back to its original dtype.

    Args:
        eps (`float`):
            Epsilon passed to the underlying `RMSNorm`.
    """

    def __init__(self, eps: float):
        super().__init__()

        self.eps = eps
        # NOTE(review): the third positional argument is presumably `elementwise_affine=False`, i.e. no learnable
        # weight — confirm against the `RMSNorm` signature.
        self.norm = RMSNorm(0, eps, False)

    def forward(self, hidden_states, scale=None):
        original_dtype = hidden_states.dtype

        # Normalize (and optionally modulate) in float32 for numerical stability.
        normed = self.norm(hidden_states.to(torch.float32))
        if scale is not None:
            normed = normed * scale

        return normed.to(original_dtype)
|
56
|
+
|
57
|
+
|
58
|
+
class MochiLayerNormContinuous(nn.Module):
    r"""
    Adaptive RMS normalization whose scale is predicted from a conditioning embedding.

    The conditioning embedding goes through SiLU and a linear projection to produce `scale`. The input is then
    RMS-normalized and multiplied by `(1 + scale.unsqueeze(1))` in float32.

    Args:
        embedding_dim (`int`):
            Output size of the scale projection; must broadcast against the last dimension of `x`.
        conditioning_embedding_dim (`int`):
            Input size of the scale projection (last dimension of `conditioning_embedding`).
        eps (`float`, defaults to `1e-5`):
            Epsilon for the RMS normalization.
        bias (`bool`, defaults to `True`):
            Whether the scale projection has a bias.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        eps=1e-5,
        bias=True,
    ):
        super().__init__()

        # AdaLN-style conditioning: SiLU -> Linear yields the modulation scale.
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
        self.norm = MochiModulatedRMSNorm(eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        original_dtype = x.dtype

        # Cast the activated conditioning to `x`'s dtype before projecting, in case it arrives in a different
        # precision (e.g. float32).
        activated = self.silu(conditioning_embedding).to(x.dtype)
        scale = self.linear_1(activated)

        # `unsqueeze(1)` inserts a broadcast axis at dim 1 (assumes conditioning is (batch, features) — TODO confirm);
        # the modulation itself is applied in float32.
        modulation = 1 + scale.unsqueeze(1).to(torch.float32)
        return self.norm(x, modulation).to(original_dtype)
|
85
|
+
|
86
|
+
|
87
|
+
class MochiRMSNormZero(nn.Module):
    r"""
    Adaptive RMS Norm used in Mochi.

    A conditioning embedding is projected (after SiLU) and split into four chunks along dim 1:
    `scale_msa`, `gate_msa`, `scale_mlp` and `gate_mlp`. The hidden states are RMS-normalized and multiplied by
    `(1 + scale_msa)` in float32. The remaining three chunks are returned for use later in the transformer block.

    Parameters:
        embedding_dim (`int`): The size of each conditioning embedding vector.
        hidden_dim (`int`): Output size of the projection; it must be divisible into the 4 chunks described above.
        eps (`float`, defaults to `1e-5`): Epsilon for the RMS normalization.
        elementwise_affine (`bool`, defaults to `False`): Accepted for API compatibility; not used by this module.
    """

    def __init__(
        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, hidden_dim)
        self.norm = RMSNorm(0, eps, False)

    def forward(
        self, hidden_states: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        original_dtype = hidden_states.dtype

        modulation = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=1)

        # Normalize and modulate in float32; `[:, None]` inserts a broadcast axis at dim 1 of the scale.
        normed = self.norm(hidden_states.to(torch.float32))
        normed = normed * (1 + scale_msa[:, None].to(torch.float32))

        return normed.to(original_dtype), gate_msa, scale_mlp, gate_mlp
|
115
|
+
|
116
|
+
|
117
|
+
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
    r"""
    Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    The block carries two streams:
    - the video tokens (`hidden_states`, width `dim`);
    - the context tokens (`encoder_hidden_states`, width `pooled_projection_dim`).

    Both streams are normalized and modulated by `temb`, then passed together to `attn1`, which returns an updated
    output for each. Unless `context_pre_only` is set, each stream then goes through its own gated residual
    feed-forward.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        pooled_projection_dim (`int`):
            The number of channels of the context (`encoder_hidden_states`) stream.
        qk_norm (`str`, defaults to `"rms_norm"`):
            Accepted for configuration compatibility; not used by this block.
        activation_fn (`str`, defaults to `"swiglu"`):
            Activation function to use in feed-forward.
        context_pre_only (`bool`, defaults to `False`):
            If `True`, the context stream is only normalized and used as attention input; it is returned unchanged,
            and no context feed-forward or post-attention context norms are created.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for the normalization layers (the attention layer uses a fixed `1e-5`).
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        pooled_projection_dim: int,
        qk_norm: str = "rms_norm",
        activation_fn: str = "swiglu",
        context_pre_only: bool = False,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.context_pre_only = context_pre_only
        # 2/3 of the usual 4x expansion.
        self.ff_inner_dim = (4 * dim * 2) // 3
        self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3

        # Produces the normalized video stream plus gate_msa / scale_mlp / gate_mlp from `temb`.
        self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)

        if not context_pre_only:
            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
        else:
            # Only a scale-modulated norm is needed when the context stream is not updated by this block.
            self.norm1_context = MochiLayerNormContinuous(
                embedding_dim=pooled_projection_dim,
                conditioning_embedding_dim=dim,
                eps=eps,
            )

        self.attn1 = MochiAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=False,
            added_kv_proj_dim=pooled_projection_dim,
            added_proj_bias=False,
            out_dim=dim,
            out_context_dim=pooled_projection_dim,
            context_pre_only=context_pre_only,
            processor=MochiAttnProcessor2_0(),
            eps=1e-5,
        )

        # Post-attention / post-feed-forward norms. The `_context` variants are only used when the context stream
        # is updated (i.e. `context_pre_only` is False), so they are `None` otherwise. These norms hold no
        # parameters, so this does not affect the state dict.
        self.norm2 = MochiModulatedRMSNorm(eps=eps)
        self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None

        self.norm3 = MochiModulatedRMSNorm(eps=eps)
        self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None

        self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
        self.ff_context = None
        if not context_pre_only:
            self.ff_context = FeedForward(
                pooled_projection_dim,
                inner_dim=self.ff_context_inner_dim,
                activation_fn=activation_fn,
                bias=False,
            )

        self.norm4 = MochiModulatedRMSNorm(eps=eps)
        self.norm4_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Run one joint video/context transformer step.

        Args:
            hidden_states: Video-token stream (last dimension `dim`).
            encoder_hidden_states: Context-token stream (last dimension `pooled_projection_dim`).
            temb: Conditioning embedding consumed by the adaptive norms (last dimension `dim`).
            encoder_attention_mask: Forwarded to `attn1` as `attention_mask`.
            image_rotary_emb: Optional rotary embedding forwarded to `attn1`.

        Returns:
            `(hidden_states, encoder_hidden_states)`. When `context_pre_only` is True, `encoder_hidden_states` is
            returned unchanged.
        """
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

        if not self.context_pre_only:
            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
                encoder_hidden_states, temb
            )
        else:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)

        attn_hidden_states, context_attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            attention_mask=encoder_attention_mask,
        )

        # Video stream: tanh-gated residual attention, then a (1 + scale)-modulated, tanh-gated residual FF.
        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))

        # Context stream: same pattern, only when this block updates it.
        if not self.context_pre_only:
            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
            )
            norm_encoder_hidden_states = self.norm3_context(
                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
            )
            context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + self.norm4_context(
                context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
            )

        return hidden_states, encoder_hidden_states
|
243
|
+
|
244
|
+
|
245
|
+
class MochiRoPE(nn.Module):
|
246
|
+
r"""
|
247
|
+
RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
248
|
+
|
249
|
+
Args:
|
250
|
+
base_height (`int`, defaults to `192`):
|
251
|
+
Base height used to compute interpolation scale for rotary positional embeddings.
|
252
|
+
base_width (`int`, defaults to `192`):
|
253
|
+
Base width used to compute interpolation scale for rotary positional embeddings.
|
254
|
+
"""
|
255
|
+
|
256
|
+
def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
|
257
|
+
super().__init__()
|
258
|
+
|
259
|
+
self.target_area = base_height * base_width
|
260
|
+
|
261
|
+
def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
|
262
|
+
edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
|
263
|
+
return (edges[:-1] + edges[1:]) / 2
|
264
|
+
|
265
|
+
def _get_positions(
|
266
|
+
self,
|
267
|
+
num_frames: int,
|
268
|
+
height: int,
|
269
|
+
width: int,
|
270
|
+
device: Optional[torch.device] = None,
|
271
|
+
dtype: Optional[torch.dtype] = None,
|
272
|
+
) -> torch.Tensor:
|
273
|
+
scale = (self.target_area / (height * width)) ** 0.5
|
274
|
+
|
275
|
+
t = torch.arange(num_frames, device=device, dtype=dtype)
|
276
|
+
h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
|
277
|
+
w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
|
278
|
+
|
279
|
+
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
280
|
+
|
281
|
+
positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
|
282
|
+
return positions
|
283
|
+
|
284
|
+
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
285
|
+
with torch.autocast(freqs.device.type, torch.float32):
|
286
|
+
# Always run ROPE freqs computation in FP32
|
287
|
+
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
|
288
|
+
|
289
|
+
freqs_cos = torch.cos(freqs)
|
290
|
+
freqs_sin = torch.sin(freqs)
|
291
|
+
return freqs_cos, freqs_sin
|
292
|
+
|
293
|
+
def forward(
|
294
|
+
self,
|
295
|
+
pos_frequencies: torch.Tensor,
|
296
|
+
num_frames: int,
|
297
|
+
height: int,
|
298
|
+
width: int,
|
299
|
+
device: Optional[torch.device] = None,
|
300
|
+
dtype: Optional[torch.dtype] = None,
|
301
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
302
|
+
pos = self._get_positions(num_frames, height, width, device, dtype)
|
303
|
+
rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
|
304
|
+
return rope_cos, rope_sin
|
305
|
+
|
306
|
+
|
307
|
+
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `24`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        num_layers (`int`, defaults to `48`):
            The number of layers of Transformer blocks to use.
        pooled_projection_dim (`int`, defaults to `1536`):
            Hidden size of the context (caption) stream processed alongside the video tokens in each block.
        in_channels (`int`, defaults to `12`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. Falls back to `in_channels` when not set.
        qk_norm (`str`, defaults to `"rms_norm"`):
            The normalization layer to use; forwarded to every `MochiTransformerBlock`.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `256`):
            Output dimension of timestep embeddings.
        activation_fn (`str`, defaults to `"swiglu"`):
            Activation function to use in feed-forward.
        max_sequence_length (`int`, defaults to `256`):
            The maximum sequence length of text embeddings supported. Recorded in the config; not used by the
            layers built here.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["MochiTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 48,
        pooled_projection_dim: int = 1536,
        in_channels: int = 12,
        out_channels: Optional[int] = None,
        qk_norm: str = "rms_norm",
        text_embed_dim: int = 4096,
        time_embed_dim: int = 256,
        activation_fn: str = "swiglu",
        max_sequence_length: int = 256,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            pos_embed_type=None,
        )

        self.time_embed = MochiCombinedTimestepCaptionEmbedding(
            embedding_dim=inner_dim,
            pooled_projection_dim=pooled_projection_dim,
            text_embed_dim=text_embed_dim,
            time_embed_dim=time_embed_dim,
            num_attention_heads=8,
        )

        # Learned per-axis (t, h, w) RoPE frequencies, zero-initialized.
        self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
        self.rope = MochiRoPE()

        # Only the last block is `context_pre_only`: its context stream is consumed but not updated.
        self.transformer_blocks = nn.ModuleList(
            [
                MochiTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    pooled_projection_dim=pooled_projection_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    context_pre_only=i == num_layers - 1,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            inner_dim,
            inner_dim,
            elementwise_affine=False,
            eps=1e-6,
            norm_type="layer_norm",
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        r"""
        Denoise a batch of video latents.

        Args:
            hidden_states (`torch.Tensor`):
                Input latents of shape `(batch, channels, frames, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Text-encoder embeddings, passed to `time_embed` together with `timestep` and
                `encoder_attention_mask`.
            timestep (`torch.LongTensor`):
                Diffusion timestep(s).
            encoder_attention_mask (`torch.Tensor`):
                Mask for `encoder_hidden_states`, used by `time_embed` and every block's attention.
            attention_kwargs (`dict`, *optional*):
                Only the `"scale"` key is consumed here: it sets the LoRA scale when the PEFT backend is in use.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput(sample=...)` if `return_dict` is True, otherwise a 1-tuple. The sample has shape
            `(batch, out_channels, frames, height, width)`.
        """
        # Check for a user-provided LoRA scale *before* popping it from the copy below; otherwise the non-PEFT
        # warning could never trigger.
        scale_was_passed = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        try:
            batch_size, _, num_frames, height, width = hidden_states.shape
            p = self.config.patch_size

            post_patch_height = height // p
            post_patch_width = width // p

            temb, encoder_hidden_states = self.time_embed(
                timestep,
                encoder_hidden_states,
                encoder_attention_mask,
                hidden_dtype=hidden_states.dtype,
            )

            # (B, C, T, H, W) -> (B*T, C, H, W): patch-embed every frame independently, then merge frames and
            # patches into one token axis per sample (assumes PatchEmbed returns (B*T, num_patches, dim)).
            hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
            hidden_states = self.patch_embed(hidden_states)
            hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

            image_rotary_emb = self.rope(
                self.pos_frequencies,
                num_frames,
                post_patch_height,
                post_patch_width,
                device=hidden_states.device,
                dtype=torch.float32,
            )

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            for block in self.transformer_blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        encoder_attention_mask,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        encoder_attention_mask=encoder_attention_mask,
                        image_rotary_emb=image_rotary_emb,
                    )
            hidden_states = self.norm_out(hidden_states, temb)
            hidden_states = self.proj_out(hidden_states)

            # Unpatchify: (B, T, H', W', p, p, C) -> (B, C, T, H', p, W', p) -> (B, C, T, H, W).
            hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
            hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
            output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
        finally:
            if USE_PEFT_BACKEND:
                # remove `lora_scale` from each PEFT layer, even if the forward pass raised
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
@@ -11,20 +11,25 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
|
15
|
-
|
16
14
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
17
15
|
|
18
16
|
import torch
|
19
17
|
import torch.nn as nn
|
18
|
+
import torch.nn.functional as F
|
20
19
|
|
21
20
|
from ...configuration_utils import ConfigMixin, register_to_config
|
22
|
-
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
23
|
-
from ...models.attention import JointTransformerBlock
|
24
|
-
from ...models.attention_processor import
|
21
|
+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
22
|
+
from ...models.attention import FeedForward, JointTransformerBlock
|
23
|
+
from ...models.attention_processor import (
|
24
|
+
Attention,
|
25
|
+
AttentionProcessor,
|
26
|
+
FusedJointAttnProcessor2_0,
|
27
|
+
JointAttnProcessor2_0,
|
28
|
+
)
|
25
29
|
from ...models.modeling_utils import ModelMixin
|
26
|
-
from ...models.normalization import AdaLayerNormContinuous
|
30
|
+
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
27
31
|
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
32
|
+
from ...utils.torch_utils import maybe_allow_in_graph
|
28
33
|
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
29
34
|
from ..modeling_outputs import Transformer2DModelOutput
|
30
35
|
|
@@ -32,7 +37,75 @@ from ..modeling_outputs import Transformer2DModelOutput
|
|
32
37
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
33
38
|
|
34
39
|
|
35
|
-
|
40
|
+
@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
    r"""
    Single-stream MMDiT transformer block, used in the Stable Diffusion 3 ControlNet.

    Unlike a joint (two-stream) block, the attention layer here is called with `encoder_hidden_states=None`, so
    only `hidden_states` take part in attention. Both the attention and feed-forward sub-layers are added back as
    residuals, each scaled by a gate produced from `temb` by an adaptive LayerNorm.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
    ):
        super().__init__()

        # Adaptive LayerNorm conditioned on `temb`. Besides the normalized input it also yields the
        # gate / shift / scale tensors consumed in `forward`.
        self.norm1 = AdaLayerNormZero(dim)

        # `JointAttnProcessor2_0` is only selected when `F.scaled_dot_product_attention` exists;
        # otherwise construction is refused outright.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        attn_processor = JointAttnProcessor2_0()

        self.attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=dim,
            bias=True,
            eps=1e-6,
            processor=attn_processor,
        )

        # No learnable affine parameters: the per-sample shift/scale from `norm1` are applied in `forward` instead.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        """
        Apply gated self-attention followed by a gated, shift/scale-modulated feed-forward layer.

        Args:
            hidden_states (`torch.Tensor`):
                Input tokens. The gates are broadcast with `[:, None]`, so this is presumably shaped
                `(batch, seq_len, dim)` — NOTE(review): confirm against callers.
            temb (`torch.Tensor`):
                Conditioning embedding passed to `norm1` as `emb`.

        Returns:
            `torch.Tensor`: The updated `hidden_states`, same shape as the input.
        """
        modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        # Attention over the image stream only; no encoder/context tokens are supplied.
        attn_out = self.attn(
            hidden_states=modulated,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + gate_msa[:, None] * attn_out

        # Modulate the non-affine LayerNorm output with the shift/scale from `norm1`, then gate the MLP residual.
        ff_in = self.norm2(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        hidden_states = hidden_states + gate_mlp[:, None] * self.ff(ff_in)

        return hidden_states
|
104
|
+
|
105
|
+
|
106
|
+
class SD3Transformer2DModel(
|
107
|
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
|
108
|
+
):
|
36
109
|
"""
|
37
110
|
The Transformer model introduced in Stable Diffusion 3.
|
38
111
|
|
@@ -268,6 +341,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
|
268
341
|
block_controlnet_hidden_states: List = None,
|
269
342
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
270
343
|
return_dict: bool = True,
|
344
|
+
skip_layers: Optional[List[int]] = None,
|
271
345
|
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
272
346
|
"""
|
273
347
|
The [`SD3Transformer2DModel`] forward method.
|
@@ -277,11 +351,11 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
|
277
351
|
Input `hidden_states`.
|
278
352
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
279
353
|
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
280
|
-
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
|
281
|
-
from the embeddings of input conditions.
|
282
|
-
timestep (
|
354
|
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
|
355
|
+
Embeddings projected from the embeddings of input conditions.
|
356
|
+
timestep (`torch.LongTensor`):
|
283
357
|
Used to indicate denoising step.
|
284
|
-
block_controlnet_hidden_states
|
358
|
+
block_controlnet_hidden_states (`list` of `torch.Tensor`):
|
285
359
|
A list of tensors that if specified are added to the residuals of transformer blocks.
|
286
360
|
joint_attention_kwargs (`dict`, *optional*):
|
287
361
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
@@ -290,6 +364,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
|
290
364
|
return_dict (`bool`, *optional*, defaults to `True`):
|
291
365
|
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
292
366
|
tuple.
|
367
|
+
skip_layers (`list` of `int`, *optional*):
|
368
|
+
A list of layer indices to skip during the forward pass.
|
293
369
|
|
294
370
|
Returns:
|
295
371
|
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
@@ -316,8 +392,17 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
|
316
392
|
temb = self.time_text_embed(timestep, pooled_projections)
|
317
393
|
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
318
394
|
|
395
|
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
396
|
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
397
|
+
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
|
398
|
+
|
399
|
+
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
|
400
|
+
|
319
401
|
for index_block, block in enumerate(self.transformer_blocks):
|
320
|
-
|
402
|
+
# Skip specified layers
|
403
|
+
is_skip = True if skip_layers is not None and index_block in skip_layers else False
|
404
|
+
|
405
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
|
321
406
|
|
322
407
|
def create_custom_forward(module, return_dict=None):
|
323
408
|
def custom_forward(*inputs):
|
@@ -334,18 +419,21 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
|
334
419
|
hidden_states,
|
335
420
|
encoder_hidden_states,
|
336
421
|
temb,
|
422
|
+
joint_attention_kwargs,
|
337
423
|
**ckpt_kwargs,
|
338
424
|
)
|
339
|
-
|
340
|
-
else:
|
425
|
+
elif not is_skip:
|
341
426
|
encoder_hidden_states, hidden_states = block(
|
342
|
-
hidden_states=hidden_states,
|
427
|
+
hidden_states=hidden_states,
|
428
|
+
encoder_hidden_states=encoder_hidden_states,
|
429
|
+
temb=temb,
|
430
|
+
joint_attention_kwargs=joint_attention_kwargs,
|
343
431
|
)
|
344
432
|
|
345
433
|
# controlnet residual
|
346
434
|
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
|
347
|
-
interval_control = len(self.transformer_blocks)
|
348
|
-
hidden_states = hidden_states + block_controlnet_hidden_states[index_block
|
435
|
+
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
|
436
|
+
hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
|
349
437
|
|
350
438
|
hidden_states = self.norm_out(hidden_states, temb)
|
351
439
|
hidden_states = self.proj_out(hidden_states)
|
@@ -340,7 +340,7 @@ class TransformerSpatioTemporalModel(nn.Module):
|
|
340
340
|
|
341
341
|
# 2. Blocks
|
342
342
|
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
343
|
-
if
|
343
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
344
344
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
345
345
|
block,
|
346
346
|
hidden_states,
|