diffusers-0.29.2-py3-none-any.whl → diffusers-0.30.0-py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
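Taken together, the listing shows 0.30.0's headline additions: the Flux transformer and pipeline, CogVideoX, AuraFlow, Kolors, Latte, Lumina, Stable Audio, SparseCtrl and HunyuanDiT ControlNets, the PAG pipeline family, and a LoRA loader reworked into lora_base.py and lora_pipeline.py (replacing loaders/lora.py). As rough orientation before the raw hunks below, here is a minimal sketch of exercising one of the new top-level exports; the checkpoint id and generation arguments are illustrative assumptions and are not part of this diff.

import torch
from diffusers import FluxPipeline  # new export in 0.30.0 per the listing above

# Assumed checkpoint id, for illustration only.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional: trades speed for lower VRAM use

image = pipe(
    prompt="a photo of a forest at dawn",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("flux_sample.png")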
diffusers/models/transformers/transformer_flux.py (new file)
@@ -0,0 +1,455 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention import FeedForward
+from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from ..modeling_outputs import Transformer2DModelOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# YiYi to-do: refactor rope related functions/classes
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    assert dim % 2 == 0, "The dimension must be even."
+
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+
+    batch_size, seq_length = pos.shape
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+
+    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+    return out.float()
+
+
+# YiYi to-do: refactor rope related functions/classes
+class EmbedND(nn.Module):
+    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(1)
+
+
+@maybe_allow_in_graph
+class FluxSingleTransformerBlock(nn.Module):
+    r"""
+    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+            processing of `context` conditions.
+    """
+
+    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+        super().__init__()
+        self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+        self.norm = AdaLayerNormZeroSingle(dim)
+        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+        processor = FluxSingleAttnProcessor2_0()
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            qk_norm="rms_norm",
+            eps=1e-6,
+            pre_only=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        image_rotary_emb=None,
+    ):
+        residual = hidden_states
+        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        gate = gate.unsqueeze(1)
+        hidden_states = gate * self.proj_out(hidden_states)
+        hidden_states = residual + hidden_states
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class FluxTransformerBlock(nn.Module):
+    r"""
+    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+            processing of `context` conditions.
+    """
+
+    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        self.norm1_context = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = FluxAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            context_pre_only=False,
+            bias=True,
+            processor=processor,
+            qk_norm=qk_norm,
+            eps=eps,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        image_rotary_emb=None,
+    ):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # Attention.
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+
+        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+        encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+        return encoder_hidden_states, hidden_states
+
+
+class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    """
+    The Transformer model introduced in Flux.
+
+    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+    Parameters:
+        patch_size (`int`): Patch size to turn the input data into small patches.
+        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 1,
+        in_channels: int = 64,
+        num_layers: int = 19,
+        num_single_layers: int = 38,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        guidance_embeds: bool = False,
+        axes_dims_rope: List[int] = [16, 56, 56],
+    ):
+        super().__init__()
+        self.out_channels = in_channels
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+        text_time_guidance_cls = (
+            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+        )
+        self.time_text_embed = text_time_guidance_cls(
+            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+        )
+
+        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                FluxTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=self.config.num_attention_heads,
+                    attention_head_dim=self.config.attention_head_dim,
+                )
+                for i in range(self.config.num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                FluxSingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=self.config.num_attention_heads,
+                    attention_head_dim=self.config.attention_head_dim,
+                )
+                for i in range(self.config.num_single_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        The [`FluxTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep ( `torch.LongTensor`):
+                Used to indicate denoising step.
+            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+        hidden_states = self.x_embedder(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+        else:
+            guidance = None
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pos_embed(ids)
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+        hidden_states = self.norm_out(hidden_states, temb)
+        output = self.proj_out(hidden_states)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
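The new FluxTransformer2DModel above packs image latents into a token sequence (x_embedder), projects text embeddings (context_embedder), runs num_layers joint MMDiT blocks followed by num_single_layers single-stream blocks, and builds rotary position embeddings from txt_ids and img_ids. Below is a minimal smoke-test sketch; the tiny configuration and input shapes are assumptions for illustration (the released Flux checkpoints use the defaults in __init__). The axes_dims_rope entries must be even, per the assert in rope(), and are chosen here to sum to attention_head_dim so the rotary embedding lines up with the attention heads.

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Deliberately tiny config for a CPU smoke test; not the shipped Flux configuration.
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=[2, 2, 4],  # even entries summing to attention_head_dim (assumed tiny values)
)

batch, img_seq, txt_seq = 1, 16, 6
out = model(
    hidden_states=torch.randn(batch, img_seq, 4),           # packed latents -> x_embedder
    encoder_hidden_states=torch.randn(batch, txt_seq, 32),   # text tokens -> context_embedder
    pooled_projections=torch.randn(batch, 16),
    timestep=torch.tensor([1.0]),
    img_ids=torch.zeros(batch, img_seq, 3),                  # one position id per rope axis
    txt_ids=torch.zeros(batch, txt_seq, 3),
)
print(out.sample.shape)  # (1, 16, 4): text tokens are sliced off before norm_out/proj_out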
diffusers/models/transformers/transformer_sd3.py
@@ -21,7 +21,7 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -95,7 +95,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.inner_dim // self.config.num_attention_heads,
+                    attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                 )
                 for i in range(self.config.num_layers)
@@ -137,6 +137,18 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         for module in self.children():
             fn_recursive_feed_forward(module, chunk_size, dim)
 
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
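The new disable_forward_chunking simply undoes the chunked feed-forward that enable_forward_chunking (whose tail is the context above) turns on. A hedged sketch, assuming the same call signature as the UNet3D method this is copied from and an assumed checkpoint id/subfolder:

from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"  # assumed repo and subfolder
)
transformer.enable_forward_chunking(chunk_size=1, dim=0)  # trade speed for lower peak memory in the MLPs
transformer.disable_forward_chunking()                    # new in 0.30.0: restore unchunked feed-forward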
@@ -150,7 +162,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -197,7 +209,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -221,6 +233,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedJointAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
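The SD3Transformer2DModel hunks above do three things: take attention_head_dim from the config instead of recomputing it, add disable_forward_chunking, and make fuse_qkv_projections install the new FusedJointAttnProcessor2_0 after fusing the projection matrices. A hedged usage sketch follows; the checkpoint id and generation arguments are assumptions for illustration and are not taken from this diff.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

# Fuse q/k/v into a single projection per attention module; per the hunk above,
# this now also swaps in FusedJointAttnProcessor2_0.
pipe.transformer.fuse_qkv_projections()
image = pipe("a castle in the clouds", num_inference_steps=28).images[0]
pipe.transformer.unfuse_qkv_projections()  # restores the original processors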
diffusers/models/unets/unet_1d_blocks.py
@@ -200,7 +200,7 @@ class MidResTemporalBlock1D(nn.Module):
 
         self.upsample = None
         if add_upsample:
-            self.upsample = Downsample1D(out_channels, use_conv=True)
+            self.upsample = Upsample1D(out_channels, use_conv=True)
 
         self.downsample = None
         if add_downsample:
diffusers/models/unets/unet_2d_condition.py
@@ -30,6 +30,7 @@ from ..attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import (
     GaussianFourierProjection,
@@ -705,7 +706,7 @@ class UNet2DConditionModel(
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -890,6 +891,8 @@ class UNet2DConditionModel(
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
 
@@ -1024,6 +1027,10 @@ class UNet2DConditionModel(
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                 )
+
+            if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
+                encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
+
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
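The UNet2DConditionModel hunks mirror the SD3 changes (drop the deprecated get_processor argument in attn_processors, install FusedAttnProcessor2_0 after fusing projections) and additionally route encoder_hidden_states through text_encoder_hid_proj, when present, before pairing them with IP-Adapter image embeddings in the ip_image_proj branch. A hedged sketch of the IP-Adapter path this code touches; the repository ids, file name, and call arguments are assumptions for illustration.

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed base checkpoint
    torch_dtype=torch.float16,
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")  # assumed repo/file
pipe.set_ip_adapter_scale(0.6)

style_image = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    prompt="a cat in the style of the reference image",
    ip_adapter_image=style_image,
    num_inference_steps=30,
).images[0]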