diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_sd3.py

@@ -11,20 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
@@ -32,7 +37,75 @@ from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
+class SD3Transformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+):
     """
     The Transformer model introduced in Stable Diffusion 3.
 
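For orientation, the new block is self-contained enough to exercise directly. A minimal sketch, with sizes chosen for illustration rather than taken from any shipped SD3 checkpoint (conventionally `dim == num_attention_heads * attention_head_dim`):

    import torch
    from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

    # Illustrative dimensions only.
    block = SD3SingleTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)

    hidden_states = torch.randn(2, 4096, 1536)  # (batch, tokens, channels)
    temb = torch.randn(2, 1536)                 # combined timestep/text embedding
    out = block(hidden_states, temb)            # same shape as hidden_states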
@@ -69,6 +142,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         pooled_projection_dim: int = 2048,
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         default_out_channels = in_channels
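These two constructor arguments are the SD3.5 switches: per the inline comment, an empty tuple reproduces SD3.0 behavior, while listing the first thirteen layer indices enables dual attention as in SD3.5. A hedged configuration sketch with deliberately tiny dimensions so it instantiates quickly; that SD3.5 checkpoints ship with `qk_norm="rms_norm"` is an assumption here, not something this diff states:

    from diffusers import SD3Transformer2DModel

    # Toy config; SD3.5-scale models would use dual_attention_layers=tuple(range(13)).
    model = SD3Transformer2DModel(
        sample_size=32,
        num_layers=2,
        num_attention_heads=4,
        attention_head_dim=32,       # inner dim = 4 * 32 = 128
        joint_attention_dim=128,
        caption_projection_dim=128,
        pooled_projection_dim=256,
        dual_attention_layers=(0,),  # enable dual attention in the first block
        qk_norm="rms_norm",          # assumed SD3.5-style value
    )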
@@ -97,6 +174,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=i == num_layers - 1,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(self.config.num_layers)
             ]
@@ -262,6 +341,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -271,11 +351,11 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 Input `hidden_states`.
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+                Embeddings projected from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -284,6 +364,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.
 
         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -310,8 +392,17 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         temb = self.time_text_embed(timestep, pooled_projections)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
+
+            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
+
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -328,18 +419,21 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-                interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
-                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
+                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
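The new `skip_layers` argument lets a caller bypass whole transformer blocks for a single forward pass, which is the primitive that skip-layer guidance builds on. A sketch of a direct call; `model`, `latents`, `prompt_embeds`, `pooled`, and `t` are assumed to be prepared as for any normal SD3 denoising step and are not defined by this diff:

    # Hypothetical, pre-prepared tensors; only skip_layers is new in 0.32.0.
    out = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled,
        timestep=t,
        skip_layers=[7, 8, 9],  # illustrative indices; these blocks are skipped entirely
    ).sample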
diffusers/models/transformers/transformer_temporal.py

@@ -340,7 +340,7 @@ class TransformerSpatioTemporalModel(nn.Module):
 
         # 2. Blocks
         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     block,
                     hidden_states,
diffusers/models/unets/unet_1d_blocks.py

@@ -217,7 +217,7 @@ class MidResTemporalBlock1D(nn.Module):
         if self.upsample:
             hidden_states = self.upsample(hidden_states)
         if self.downsample:
-            self.downsample = self.downsample(hidden_states)
+            hidden_states = self.downsample(hidden_states)
 
         return hidden_states
 
diffusers/models/unets/unet_2d.py

@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             conditioning with `class_embed_type` equal to `None`.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         out_channels: int = 3,
         center_input_sample: bool = False,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         super().__init__()
 
         self.sample_size = sample_size
-        time_embed_dim = block_out_channels[0] * 4
+        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
 
         # Check inputs
         if len(down_block_types) != len(up_block_types):
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.Tensor,
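Together, `_supports_gradient_checkpointing = True` and the `_set_gradient_checkpointing` hook wire `UNet2DModel` into the standard `ModelMixin` toggle. A minimal sketch of the effect:

    import torch
    from diffusers import UNet2DModel

    model = UNet2DModel()  # default config; the new time_embedding_dim kwarg is also available
    model.enable_gradient_checkpointing()  # now supported on this class

    sample = torch.randn(1, 3, 64, 64)
    out = model(sample, timestep=torch.tensor([10])).sample
    out.mean().backward()  # activations recomputed in backward, lowering peak memory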
diffusers/models/unets/unet_2d_blocks.py

@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
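The recurring edit across these blocks replaces `self.training` with `torch.is_grad_enabled()` as the checkpointing gate, so recomputation also engages when a module kept in `eval()` mode is trained through (as in LoRA fine-tuning) and is skipped automatically under `torch.no_grad()`. A small self-contained illustration of why the gate matters:

    import torch

    layer = torch.nn.Linear(8, 8)
    layer.eval()  # e.g. a frozen-mode backbone during LoRA training
    gradient_checkpointing = True

    x = torch.randn(2, 8, requires_grad=True)

    # layer.training is False here, so the old `self.training and ...` gate would
    # silently skip checkpointing; grad mode is the condition that actually matters.
    if torch.is_grad_enabled() and gradient_checkpointing:
        y = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
    else:
        y = layer(x)
    y.sum().backward()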
@@ -859,7 +882,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
 
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -1257,7 +1303,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
 
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1371,7 +1417,7 @@ class DownBlock2D(nn.Module):
         output_states = ()
 
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1859,7 +1905,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         output_states = ()
 
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2011,7 +2057,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
             mask = attention_mask
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2106,7 +2152,7 @@ class KDownBlock2D(nn.Module):
         output_states = ()
 
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2215,7 +2261,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None
 
+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx
 
     def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2520,7 +2587,7 @@ class CrossAttnUpBlock2D(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2653,7 +2720,7 @@ class UpBlock2D(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3183,7 +3250,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3341,7 +3408,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -3444,7 +3511,7 @@ class KUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
 
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3572,7 +3639,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diffusers/models/unets/unet_2d_condition.py

@@ -170,7 +170,7 @@ class UNet2DConditionModel(
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 4,
         out_channels: int = 4,
         center_input_sample: bool = False,
@@ -463,7 +463,6 @@ class UNet2DConditionModel(
                 dropout=dropout,
             )
             self.up_blocks.append(up_block)
-            prev_output_channel = output_channel
 
         # out
         if norm_num_groups is not None:
@@ -599,7 +598,7 @@ class UNet2DConditionModel(
             )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
-                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+                f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
             )
         else:
             self.encoder_hid_proj = None
@@ -679,7 +678,9 @@ class UNet2DConditionModel(
             # Kandinsky 2.2 ControlNet
             self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+            raise ValueError(
+                f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
+            )
 
     def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
         if attention_type in ["gated", "gated-text-image"]:
@@ -990,7 +991,7 @@ class UNet2DConditionModel(
             image_embs = added_cond_kwargs.get("image_embeds")
             aug_emb = self.add_embedding(image_embs)
         elif self.config.addition_embed_type == "image_hint":
-            # Kandinsky 2.2 - style
+            # Kandinsky 2.2 ControlNet - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
@@ -1009,7 +1010,7 @@ class UNet2DConditionModel(
             # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )
 
             image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1018,14 +1019,14 @@ class UNet2DConditionModel(
             # Kandinsky 2.2 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(image_embeds)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )
 
             if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
@@ -1140,7 +1141,6 @@ class UNet2DConditionModel(
         # 1. time
         t_emb = self.get_time_embed(sample=sample, timestep=timestep)
         emb = self.time_embedding(t_emb, timestep_cond)
-        aug_emb = None
 
         class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
         if class_emb is not None:
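The widened `sample_size` annotation documents that `UNet2DConditionModel` accepts a `(height, width)` pair, not only a square size. A hedged sketch using a deliberately tiny, made-up config (real checkpoints define their own):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=(96, 64),  # (height, width) is now covered by the type hint
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
    )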
diffusers/models/unets/unet_3d_blocks.py

@@ -1078,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
         )
 
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1168,7 +1168,7 @@ class DownBlockSpatioTemporal(nn.Module):
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1281,7 +1281,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
 
         blocks = list(zip(self.resnets, self.attentions))
         for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
         res_hidden_states_tuple: Tuple[torch.Tensor, ...],
         temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
+        upsample_size: Optional[int] = None,
     ) -> torch.Tensor:
         for resnet in self.resnets:
             # pop res hidden states
@@ -1383,7 +1384,7 @@ class UpBlockSpatioTemporal(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
         temb: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
+        upsample_size: Optional[int] = None,
     ) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
@@ -1493,7 +1495,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states