diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +9 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +1 -1
- diffusers/loaders/single_file_model.py +5 -0
- diffusers/loaders/single_file_utils.py +242 -2
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +5 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +128 -17
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +344 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/vq_model.py
ADDED
@@ -0,0 +1,182 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
+from ..modeling_utils import ModelMixin
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            The encoded output sample from the last layer of the model.
+    """
+
+    latents: torch.Tensor
+
+
+class VQModel(ModelMixin, ConfigMixin):
+    r"""
+    A VQ-VAE model for decoding latent representations.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
+        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        norm_type (`str`, *optional*, defaults to `"group"`):
+            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
+        vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
+        norm_type: str = "group",  # group, spatial
+        mid_block_add_attention=True,
+        lookup_from_codebook=False,
+        force_upcast=False,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=False,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
+        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            norm_type=norm_type,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+    @apply_forward_hook
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    @apply_forward_hook
+    def decode(
+        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, commit_loss, _ = self.quantize(h)
+        elif self.config.lookup_from_codebook:
+            quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        else:
+            quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        quant2 = self.post_quant_conv(quant)
+        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
+
+        if not return_dict:
+            return dec, commit_loss
+
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
+
+    def forward(
+        self, sample: torch.Tensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
+        r"""
+        The [`VQModel`] forward method.
+
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vq_model.VQEncoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
+                is returned.
+        """
+
+        h = self.encode(sample).latents
+        dec = self.decode(h)
+
+        if not return_dict:
+            return dec.sample, dec.commit_loss
+        return dec
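The hunk above adds `VQModel` under `diffusers/models/autoencoders/`; per the file list, the old `diffusers/models/vq_model.py` shrinks to a thin shim (+8 -165), so `from diffusers import VQModel` should keep working. A minimal encode/decode round trip, as an illustration only (not part of the diff), assuming a default-configured model on CPU:

import torch

from diffusers import VQModel

# Default config: 3 latent channels and a 256-entry codebook (see the class above).
model = VQModel()

image = torch.randn(1, 3, 32, 32)
latents = model.encode(image).latents      # VQEncoderOutput.latents
decoded = model.decode(latents)            # DecoderOutput carrying sample and commit_loss
print(decoded.sample.shape)                # torch.Size([1, 3, 32, 32])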
diffusers/models/controlnet_xs.py
CHANGED
@@ -851,8 +851,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
     @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
         Returns:
@@ -911,7 +911,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    #
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.
@@ -927,7 +927,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
         self.set_attn_processor(processor)
 
-    #
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
@@ -952,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             setattr(upsample_block, "b1", b1)
             setattr(upsample_block, "b2", b2)
 
-    #
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
    def disable_freeu(self):
         """Disables the FreeU mechanism."""
         freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -961,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)
 
-    #
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -985,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-    #
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
 
diffusers/models/embeddings.py
CHANGED
@@ -123,7 +123,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding"""
+    """2D Image to Patch Embedding with support for SD3 cropping."""
 
     def __init__(
         self,
@@ -137,12 +137,14 @@ class PatchEmbed(nn.Module):
         bias=True,
         interpolation_scale=1,
         pos_embed_type="sincos",
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
         num_patches = (height // patch_size) * (width // patch_size)
         self.flatten = flatten
         self.layer_norm = layer_norm
+        self.pos_embed_max_size = pos_embed_max_size
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
@@ -153,26 +155,55 @@ class PatchEmbed(nn.Module):
             self.norm = None
 
         self.patch_size = patch_size
-        # See:
-        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
+
+        # Calculate positional embeddings based on max size or default
+        if pos_embed_max_size:
+            grid_size = pos_embed_max_size
+        else:
+            grid_size = int(num_patches**0.5)
+
         if pos_embed_type is None:
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim,
-                int(num_patches**0.5),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
+                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+            persistent = True if pos_embed_max_size else False
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
+    def cropped_pos_embed(self, height, width):
+        """Crops positional embeddings for SD3 compatibility."""
+        if self.pos_embed_max_size is None:
+            raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        if height > self.pos_embed_max_size:
+            raise ValueError(
+                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+        if width > self.pos_embed_max_size:
+            raise ValueError(
+                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+
+        top = (self.pos_embed_max_size - height) // 2
+        left = (self.pos_embed_max_size - width) // 2
+        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+        return spatial_pos_embed
+
     def forward(self, latent):
-        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+        if self.pos_embed_max_size is not None:
+            height, width = latent.shape[-2:]
+        else:
+            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
 
         latent = self.proj(latent)
         if self.flatten:
@@ -181,20 +212,20 @@ class PatchEmbed(nn.Module):
             latent = self.norm(latent)
         if self.pos_embed is None:
             return latent.to(latent.dtype)
-
-
-
-        if self.height != height or self.width != width:
-            pos_embed = get_2d_sincos_pos_embed(
-                embed_dim=self.pos_embed.shape[-1],
-                grid_size=(height, width),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
-            )
-            pos_embed = torch.from_numpy(pos_embed)
-            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+        # Interpolate or crop positional embeddings as needed
+        if self.pos_embed_max_size:
+            pos_embed = self.cropped_pos_embed(height, width)
         else:
-            pos_embed = self.pos_embed
+            if self.height != height or self.width != width:
+                pos_embed = get_2d_sincos_pos_embed(
+                    embed_dim=self.pos_embed.shape[-1],
+                    grid_size=(height, width),
+                    base_size=self.base_size,
+                    interpolation_scale=self.interpolation_scale,
+                )
+                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            else:
+                pos_embed = self.pos_embed
 
         return (latent + pos_embed).to(latent.dtype)
 
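The `pos_embed_max_size` path added above registers one large sincos grid as a persistent buffer and crops a centered window out of it at run time, instead of re-interpolating the grid for every new resolution. A shape-level sketch, as an illustration only, with hypothetical SD3-like settings (2x2 patches, 16 latent channels, a 192x192 positional grid):

import torch

from diffusers.models.embeddings import PatchEmbed

patch_embed = PatchEmbed(
    height=1024,
    width=1024,
    patch_size=2,
    in_channels=16,
    embed_dim=64,
    pos_embed_max_size=192,  # precomputed grid; forward() crops a window from it
)

latent = torch.randn(1, 16, 128, 128)   # 128 // 2 = 64 patches per side
tokens = patch_embed(latent)
print(tokens.shape)                     # torch.Size([1, 4096, 64])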
@@ -626,6 +657,25 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        pooled_projections = self.text_embedder(pooled_projection)
+
+        conditioning = timesteps_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
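`CombinedTimestepTextProjEmbeddings` appears to be the conditioning embedder for the new SD3 transformer (see `transformer_sd3.py` in the file list): a 256-channel sinusoidal timestep embedding and a SiLU-projected pooled text embedding are summed into a single conditioning vector. A shape-level sketch, as an illustration only, with hypothetical dimensions:

import torch

from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

embedder = CombinedTimestepTextProjEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)

timestep = torch.tensor([999])
pooled_text_embed = torch.randn(1, 2048)   # e.g. pooled CLIP text features

conditioning = embedder(timestep, pooled_text_embed)
print(conditioning.shape)                  # torch.Size([1, 1536])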
@@ -1001,6 +1051,8 @@ class PixArtAlphaTextProjection(nn.Module):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         if act_fn == "gelu_tanh":
             self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu":
+            self.act_1 = nn.SiLU()
         elif act_fn == "silu_fp32":
             self.act_1 = FP32SiLU()
         else:
@@ -1014,6 +1066,39 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
 class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
@@ -1042,8 +1127,6 @@ class IPAdapterPlusImageProjection(nn.Module):
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward  # Lazy import to avoid circular import
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
 
         self.proj_in = nn.Linear(embed_dims, hidden_dims)
@@ -1051,26 +1134,9 @@ class IPAdapterPlusImageProjection(nn.Module):
         self.proj_out = nn.Linear(hidden_dims, output_dims)
         self.norm_out = nn.LayerNorm(output_dims)
 
-        self.layers = nn.ModuleList(
-
-
-                nn.ModuleList(
-                    [
-                        nn.LayerNorm(hidden_dims),
-                        nn.LayerNorm(hidden_dims),
-                        Attention(
-                            query_dim=hidden_dims,
-                            dim_head=dim_head,
-                            heads=heads,
-                            out_bias=False,
-                        ),
-                        nn.Sequential(
-                            nn.LayerNorm(hidden_dims),
-                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-                        ),
-                    ]
-                )
-            )
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -1084,52 +1150,14 @@ class IPAdapterPlusImageProjection(nn.Module):
 
         x = self.proj_in(x)
 
-        for ln0, ln1, attn, ff in self.layers:
+        for block in self.layers:
             residual = latents
-
-            encoder_hidden_states = ln0(x)
-            latents = ln1(latents)
-            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-            latents = attn(latents, encoder_hidden_states) + residual
-            latents = ff(latents) + latents
+            latents = block(x, latents, residual)
 
         latents = self.proj_out(latents)
         return self.norm_out(latents)
 
 
-class IPAdapterPlusImageProjectionBlock(nn.Module):
-    def __init__(
-        self,
-        embed_dims: int = 768,
-        dim_head: int = 64,
-        heads: int = 16,
-        ffn_ratio: float = 4,
-    ) -> None:
-        super().__init__()
-        from .attention import FeedForward
-
-        self.ln0 = nn.LayerNorm(embed_dims)
-        self.ln1 = nn.LayerNorm(embed_dims)
-        self.attn = Attention(
-            query_dim=embed_dims,
-            dim_head=dim_head,
-            heads=heads,
-            out_bias=False,
-        )
-        self.ff = nn.Sequential(
-            nn.LayerNorm(embed_dims),
-            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-        )
-
-    def forward(self, x, latents, residual):
-        encoder_hidden_states = self.ln0(x)
-        latents = self.ln1(latents)
-        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-        latents = self.attn(latents, encoder_hidden_states) + residual
-        latents = self.ff(latents) + latents
-        return latents
-
-
 class IPAdapterFaceIDPlusImageProjection(nn.Module):
     """FacePerceiverResampler of IP-Adapter Plus.
 
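The hunks above move the per-depth layer-norm / attention / feed-forward step of the IP-Adapter Plus resampler into the reusable `IPAdapterPlusImageProjectionBlock` defined before the resampler that now consumes it; the forward computation itself is unchanged. A quick shape check, as an illustration only, with small hypothetical dimensions:

import torch

from diffusers.models.embeddings import IPAdapterPlusImageProjection

resampler = IPAdapterPlusImageProjection(
    embed_dims=32,     # per-patch image feature size
    output_dims=16,    # projected token size handed to cross-attention
    hidden_dims=64,
    depth=2,
    dim_head=8,
    heads=4,
    num_queries=8,
    ffn_ratio=2,
)

image_embeds = torch.randn(1, 49, 32)   # e.g. a 7x7 grid of image patch features
tokens = resampler(image_embeds)
print(tokens.shape)                     # torch.Size([1, 8, 16])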
diffusers/models/model_loading_utils.py
CHANGED
@@ -18,13 +18,19 @@ import importlib
 import inspect
 import os
 from collections import OrderedDict
+from pathlib import Path
 from typing import List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
+    WEIGHTS_INDEX_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -175,3 +181,52 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
     load(model_to_load)
 
     return error_msgs
+
+
+def _fetch_index_file(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    resume_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        )
+    else:
+        index_file_in_repo = Path(
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        ).as_posix()
+        try:
+            index_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=index_file_in_repo,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            index_file = Path(index_file)
+        except (EntryNotFoundError, EnvironmentError):
+            index_file = None
+
+    return index_file