diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
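Beyond the long tail of new pipelines (Allegro, CogView3, Flux variants, HunyuanVideo, LTX, Mochi, Sana), the headline structural change in this listing is the new `diffusers/quantizers` package, which adds bitsandbytes, GGUF, and torchao backends behind a shared `quantization_config` loading argument. A rough sketch of how that argument is typically wired up — the checkpoint id, subfolder, and dtype choices here are illustrative, not taken from this diff:

```python
# Hedged sketch: load a transformer through the 4-bit bitsandbytes backend
# added under diffusers/quantizers. Assumes bitsandbytes is installed; the
# model id and subfolder are placeholders for illustration.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```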
New file: `diffusers/models/transformers/transformer_cogview3plus.py` (`@@ -0,0 +1,386 @@`). The full contents of the added file:
```python
# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    CogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CogView3PlusTransformerBlock(nn.Module):
    r"""
    Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
    """

    def __init__(
        self,
        dim: int = 2560,
        num_attention_heads: int = 64,
        attention_head_dim: int = 40,
        time_embed_dim: int = 512,
    ):
        super().__init__()

        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=dim,
            bias=True,
            qk_norm="layer_norm",
            elementwise_affine=False,
            eps=1e-6,
            processor=CogVideoXAttnProcessor2_0(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)

        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        emb: torch.Tensor,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        (
            norm_hidden_states,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            norm_encoder_hidden_states,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = self.norm1(hidden_states, encoder_hidden_states, emb)

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        # feed-forward
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]

        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        return hidden_states, encoder_hidden_states
```
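`CogView3PlusTransformerBlock` is a dual-stream DiT-style block: `norm1` produces per-sample shift, scale, and gate vectors for both the image tokens and the text tokens, and each residual branch is gated before being added back (with a final fp16 clamp to avoid overflow). A self-contained sketch of that modulation pattern, with illustrative shapes and random stand-ins for the values the real block derives from the conditioning embedding `emb`:

```python
import torch
import torch.nn.functional as F

batch, seq, dim = 2, 8, 16
x = torch.randn(batch, seq, dim)
# The real block computes these from a Linear projection of `emb`;
# random values stand in for them here.
shift, scale, gate = torch.randn(3, batch, dim).unbind(0)

x_norm = F.layer_norm(x, (dim,))                         # parameter-free LayerNorm
x_mod = x_norm * (1 + scale[:, None]) + shift[:, None]   # shift/scale modulation
branch = x_mod                                           # stand-in for the attention or feed-forward output
x = x + gate.unsqueeze(1) * branch                       # gated residual update
```

The hunk continues with the model class that stacks these blocks: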
```python
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
    Diffusion](https://huggingface.co/papers/2403.05121).

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, defaults to `40`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `64`):
            The number of heads to use for multi-head attention.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        condition_dim (`int`, defaults to `256`):
            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
            crop_coords).
        pos_embed_max_size (`int`, defaults to `128`):
            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and
            added to input patched latents, where `H` and `W` are the latent height and width respectively. A value
            of 128 means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
            patch_size => 128 * 8 * 2 => 2048`.
        sample_size (`int`, defaults to `128`):
            The base resolution of input latents. If height/width is not provided during generation, this value is
            used to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 30,
        attention_head_dim: int = 40,
        num_attention_heads: int = 64,
        out_channels: int = 16,
        text_embed_dim: int = 4096,
        time_embed_dim: int = 512,
        condition_dim: int = 256,
        pos_embed_max_size: int = 128,
        sample_size: int = 128,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
        # Each of these are sincos embeddings of shape 2 * condition_dim
        self.pooled_projection_dim = 3 * 2 * condition_dim

        self.patch_embed = CogView3PlusPatchEmbed(
            in_channels=in_channels,
            hidden_size=self.inner_dim,
            patch_size=patch_size,
            text_hidden_size=text_embed_dim,
            pos_embed_max_size=pos_embed_max_size,
        )

        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
            embedding_dim=time_embed_dim,
            condition_dim=condition_dim,
            pooled_projection_dim=self.pooled_projection_dim,
            timesteps_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                CogView3PlusTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            embedding_dim=self.inner_dim,
            conditioning_embedding_dim=time_embed_dim,
            elementwise_affine=False,
            eps=1e-6,
        )
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
```
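Both attention-processor methods above are `# Copied from` the 2D UNet: `attn_processors` walks the module tree and keys each processor by its weight path, while `set_attn_processor` accepts either a single processor applied to every `Attention` layer or a dict keyed by those same paths. A hedged usage sketch, assuming `model` is an instantiated `CogView3PlusTransformer2DModel`:

```python
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0

procs = model.attn_processors
# e.g. {"transformer_blocks.0.attn1.processor": CogVideoXAttnProcessor2_0(), ...}

# One processor instance shared by every Attention layer...
model.set_attn_processor(CogVideoXAttnProcessor2_0())

# ...or a per-layer dict keyed by the same paths returned above.
model.set_attn_processor({name: CogVideoXAttnProcessor2_0() for name in procs})
```

The file closes with the forward pass: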
```python
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        original_size: torch.Tensor,
        target_size: torch.Tensor,
        crop_coords: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`CogView3PlusTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor`):
                Input `hidden_states` of shape `(batch size, channel, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
                `(batch_size, sequence_len, text_embed_dim)`
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            original_size (`torch.Tensor`):
                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`torch.Tensor`):
                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crop_coords (`torch.Tensor`):
                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
                The denoised latents using provided inputs as conditioning.
        """
        height, width = hidden_states.shape[-2:]
        text_seq_length = encoder_hidden_states.shape[1]

        hidden_states = self.patch_embed(
            hidden_states, encoder_hidden_states
        )  # takes care of adding positional embeddings too.
        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    emb=emb,
                )

        hidden_states = self.norm_out(hidden_states, emb)
        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)

        # unpatchify
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
        )
        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
```