diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +15 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +13 -1
- diffusers/loaders/single_file_model.py +15 -8
- diffusers/loaders/single_file_utils.py +267 -17
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +7 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_sd3.py +418 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +138 -20
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +353 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +11 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/auto_pipeline.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py
CHANGED
@@ -123,7 +123,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding"""
+    """2D Image to Patch Embedding with support for SD3 cropping."""
 
     def __init__(
         self,
@@ -137,12 +137,14 @@ class PatchEmbed(nn.Module):
         bias=True,
         interpolation_scale=1,
         pos_embed_type="sincos",
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
         num_patches = (height // patch_size) * (width // patch_size)
         self.flatten = flatten
         self.layer_norm = layer_norm
+        self.pos_embed_max_size = pos_embed_max_size
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
@@ -153,26 +155,55 @@ class PatchEmbed(nn.Module):
         self.norm = None
 
         self.patch_size = patch_size
-        # See:
-        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
+
+        # Calculate positional embeddings based on max size or default
+        if pos_embed_max_size:
+            grid_size = pos_embed_max_size
+        else:
+            grid_size = int(num_patches**0.5)
+
         if pos_embed_type is None:
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim,
-                int(num_patches**0.5),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
+                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+            persistent = True if pos_embed_max_size else False
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
+    def cropped_pos_embed(self, height, width):
+        """Crops positional embeddings for SD3 compatibility."""
+        if self.pos_embed_max_size is None:
+            raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        if height > self.pos_embed_max_size:
+            raise ValueError(
+                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+        if width > self.pos_embed_max_size:
+            raise ValueError(
+                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+
+        top = (self.pos_embed_max_size - height) // 2
+        left = (self.pos_embed_max_size - width) // 2
+        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+        return spatial_pos_embed
+
     def forward(self, latent):
-        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+        if self.pos_embed_max_size is not None:
+            height, width = latent.shape[-2:]
+        else:
+            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
 
         latent = self.proj(latent)
         if self.flatten:
@@ -181,20 +212,20 @@ class PatchEmbed(nn.Module):
             latent = self.norm(latent)
         if self.pos_embed is None:
             return latent.to(latent.dtype)
-
-        # Interpolate positional embeddings if needed.
-        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161)
-        if self.height != height or self.width != width:
-            pos_embed = get_2d_sincos_pos_embed(
-                embed_dim=self.pos_embed.shape[-1],
-                grid_size=(height, width),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
-            )
-            pos_embed = torch.from_numpy(pos_embed)
-            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+        # Interpolate or crop positional embeddings as needed
+        if self.pos_embed_max_size:
+            pos_embed = self.cropped_pos_embed(height, width)
         else:
-            pos_embed = self.pos_embed
+            if self.height != height or self.width != width:
+                pos_embed = get_2d_sincos_pos_embed(
+                    embed_dim=self.pos_embed.shape[-1],
+                    grid_size=(height, width),
+                    base_size=self.base_size,
+                    interpolation_scale=self.interpolation_scale,
+                )
+                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            else:
+                pos_embed = self.pos_embed
 
         return (latent + pos_embed).to(latent.dtype)
 
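The cropping path above is easy to sanity-check in isolation. Below is a minimal sketch (not part of the diff) of the same center-crop arithmetic: a toy `pos_embed` table whose channels encode their own (row, col) coordinates stands in for the real sincos buffer, so the crop offsets are directly visible.

```python
import torch

# Toy stand-in for PatchEmbed's buffer: (1, max*max, 2), channel = (row, col).
pos_embed_max_size, patch_size = 8, 2
grid = torch.stack(
    torch.meshgrid(
        torch.arange(pos_embed_max_size), torch.arange(pos_embed_max_size), indexing="ij"
    ),
    dim=-1,
).float()
pos_embed = grid.reshape(1, pos_embed_max_size * pos_embed_max_size, 2)

def cropped_pos_embed(height, width):
    # Same arithmetic as the new PatchEmbed.cropped_pos_embed: latent pixels
    # -> patch counts, then a centered crop out of the max-size grid.
    height, width = height // patch_size, width // patch_size
    top = (pos_embed_max_size - height) // 2
    left = (pos_embed_max_size - width) // 2
    spatial = pos_embed.reshape(1, pos_embed_max_size, pos_embed_max_size, -1)
    spatial = spatial[:, top : top + height, left : left + width, :]
    return spatial.reshape(1, -1, spatial.shape[-1])

out = cropped_pos_embed(8, 12)  # an 8x12 latent -> 4x6 patches
print(out.shape)   # torch.Size([1, 24, 2])
print(out[0, 0])   # tensor([2., 1.]) -> crop starts at row 2, col 1
```

Registering the buffer with `persistent=True` when `pos_embed_max_size` is set fits this scheme: the full max-size table travels with the checkpoint, and each resolution takes a centered slice rather than recomputing embeddings.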
@@ -626,6 +657,25 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        pooled_projections = self.text_embedder(pooled_projection)
+
+        conditioning = timesteps_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
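For orientation, the new SD3 conditioning block just projects both signals into the same embedding space and sums them. A shape-level sketch with toy stand-ins (plain `Linear`+`SiLU` modules in place of `Timesteps`/`TimestepEmbedding`/`PixArtAlphaTextProjection`; the dimensions are illustrative, not SD3's actual config):

```python
import torch
import torch.nn as nn

class ToyCombinedEmbed(nn.Module):
    """Shape-level stand-in for CombinedTimestepTextProjEmbeddings."""

    def __init__(self, embedding_dim=1536, pooled_projection_dim=2048):
        super().__init__()
        # Stand-in for Timesteps(256) followed by TimestepEmbedding(256 -> embedding_dim).
        self.timestep_embedder = nn.Sequential(nn.Linear(256, embedding_dim), nn.SiLU())
        # Stand-in for PixArtAlphaTextProjection(pooled_dim -> embedding_dim, act_fn="silu").
        self.text_embedder = nn.Sequential(nn.Linear(pooled_projection_dim, embedding_dim), nn.SiLU())

    def forward(self, timesteps_proj, pooled_projection):
        # Both branches land in the same space; the conditioning is their sum.
        return self.timestep_embedder(timesteps_proj) + self.text_embedder(pooled_projection)

m = ToyCombinedEmbed()
cond = m(torch.randn(2, 256), torch.randn(2, 2048))
print(cond.shape)  # torch.Size([2, 1536])
```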
@@ -1001,6 +1051,8 @@ class PixArtAlphaTextProjection(nn.Module):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         if act_fn == "gelu_tanh":
             self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu":
+            self.act_1 = nn.SiLU()
         elif act_fn == "silu_fp32":
             self.act_1 = FP32SiLU()
         else:
@@ -1014,6 +1066,39 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
 class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
@@ -1042,8 +1127,6 @@ class IPAdapterPlusImageProjection(nn.Module):
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward  # Lazy import to avoid circular import
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
 
         self.proj_in = nn.Linear(embed_dims, hidden_dims)
@@ -1051,26 +1134,9 @@ class IPAdapterPlusImageProjection(nn.Module):
         self.proj_out = nn.Linear(hidden_dims, output_dims)
         self.norm_out = nn.LayerNorm(output_dims)
 
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        nn.LayerNorm(hidden_dims),
-                        nn.LayerNorm(hidden_dims),
-                        Attention(
-                            query_dim=hidden_dims,
-                            dim_head=dim_head,
-                            heads=heads,
-                            out_bias=False,
-                        ),
-                        nn.Sequential(
-                            nn.LayerNorm(hidden_dims),
-                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-                        ),
-                    ]
-                )
-            )
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -1084,52 +1150,14 @@ class IPAdapterPlusImageProjection(nn.Module):
 
         x = self.proj_in(x)
 
-        for ln0, ln1, attn, ff in self.layers:
+        for block in self.layers:
             residual = latents
-
-            encoder_hidden_states = ln0(x)
-            latents = ln1(latents)
-            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-            latents = attn(latents, encoder_hidden_states) + residual
-            latents = ff(latents) + latents
+            latents = block(x, latents, residual)
 
         latents = self.proj_out(latents)
         return self.norm_out(latents)
 
 
-class IPAdapterPlusImageProjectionBlock(nn.Module):
-    def __init__(
-        self,
-        embed_dims: int = 768,
-        dim_head: int = 64,
-        heads: int = 16,
-        ffn_ratio: float = 4,
-    ) -> None:
-        super().__init__()
-        from .attention import FeedForward
-
-        self.ln0 = nn.LayerNorm(embed_dims)
-        self.ln1 = nn.LayerNorm(embed_dims)
-        self.attn = Attention(
-            query_dim=embed_dims,
-            dim_head=dim_head,
-            heads=heads,
-            out_bias=False,
-        )
-        self.ff = nn.Sequential(
-            nn.LayerNorm(embed_dims),
-            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-        )
-
-    def forward(self, x, latents, residual):
-        encoder_hidden_states = self.ln0(x)
-        latents = self.ln1(latents)
-        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-        latents = self.attn(latents, encoder_hidden_states) + residual
-        latents = self.ff(latents) + latents
-        return latents
-
-
 class IPAdapterFaceIDPlusImageProjection(nn.Module):
     """FacePerceiverResampler of IP-Adapter Plus.
 
diffusers/models/model_loading_utils.py
CHANGED
@@ -18,13 +18,19 @@ import importlib
 import inspect
 import os
 from collections import OrderedDict
+from pathlib import Path
 from typing import List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
+    WEIGHTS_INDEX_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -175,3 +181,52 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
     load(model_to_load)
 
     return error_msgs
+
+
+def _fetch_index_file(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    resume_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        )
+    else:
+        index_file_in_repo = Path(
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        ).as_posix()
+        try:
+            index_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=index_file_in_repo,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            index_file = Path(index_file)
+        except (EntryNotFoundError, EnvironmentError):
+            index_file = None
+
+    return index_file
diffusers/models/modeling_utils.py
CHANGED
@@ -16,6 +16,7 @@
 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,7 +26,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
@@ -33,9 +34,12 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@ from ..utils.hub_utils import (
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@ from .model_loading_utils import (
 
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -263,6 +270,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
+                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
+                This is to establish a common default size for this argument across different libraries in the Hugging
+                Face ecosystem (`transformers`, and `accelerate`, for example).
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -296,6 +311,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
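The `{suffix}` pattern built above is what `huggingface_hub.split_torch_state_dict_into_shards` fills in per shard (empty for a single file, `-00001-of-00002`-style otherwise). A small sketch reproducing the string logic, with a hypothetical `build_pattern` helper:

```python
def build_pattern(weights_name: str) -> str:
    # Same split/join logic as the new save_pretrained code above.
    parts = weights_name.split(".")
    if len(parts) not in (2, 3):
        raise ValueError(f"Invalid {weights_name} provided.")
    return parts[0] + "{suffix}." + ".".join(parts[1:])

print(build_pattern("diffusion_pytorch_model.safetensors"))
# diffusion_pytorch_model{suffix}.safetensors
print(build_pattern("diffusion_pytorch_model.fp16.safetensors"))
# diffusion_pytorch_model{suffix}.fp16.safetensors
```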
@@ -317,18 +340,58 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
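Taken together, `save_pretrained` now shards automatically once the state dict exceeds `max_shard_size`, writing the shards plus a `*.index.json` weight map. A hedged usage sketch (the repo id is illustrative; with the default `"10GB"` most models still save as a single file):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
# Force sharding with a deliberately small limit.
unet.save_pretrained("./unet-sharded", max_shard_size="1GB")
# ./unet-sharded/ should now contain shards named like
# diffusion_pytorch_model-00001-of-0000N.safetensors plus the index JSON.
```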
@@ -399,7 +462,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                 A map that specifies where each submodule should go. It doesn't need to be defined for each
                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
-                same device.
+                same device. Defaults to `None`, meaning that the model will be loaded on CPU.
 
                 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
@@ -566,6 +629,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
@@ -590,7 +679,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
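On the loading side nothing changes for callers: when `from_pretrained` finds an index file it fetches the shards itself and hands the cached folder to accelerate, as the rest of this hunk shows. A minimal sketch, continuing the save example above:

```python
from diffusers import UNet2DConditionModel

# The index file in ./unet-sharded triggers the sharded path: the shards are
# resolved via _get_checkpoint_shard_files and dispatched by accelerate.
unet = UNet2DConditionModel.from_pretrained("./unet-sharded")
```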
@@ -606,11 +709,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise
-
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         model = cls.from_config(config, **unused_kwargs)
 
         # if device_map is None, load the state dict and move the params from meta device to the cpu
-        if device_map is None:
+        if device_map is None and not is_sharded:
             param_device = "cpu"
             state_dict = load_state_dict(model_file, variant=variant)
             model._convert_deprecated_attention_blocks(state_dict)
@@ -666,17 +774,22 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
+                force_hook = True
                 device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+                if device_map is None and is_sharded:
+                    # we load the parameters on the cpu
+                    device_map = {"": "cpu"}
+                    force_hook = False
                 try:
                     accelerate.load_checkpoint_and_dispatch(
                         model,
-                        model_file,
+                        model_file if not is_sharded else sharded_ckpt_cached_folder,
                         device_map,
                         max_memory=max_memory,
                         offload_folder=offload_folder,
                         offload_state_dict=offload_state_dict,
                         dtype=torch_dtype,
-                        force_hooks=True,
+                        force_hooks=force_hook,
                         strict=True,
                     )
                 except AttributeError as e:
@@ -700,12 +813,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     model._temp_convert_self_to_deprecated_attention_blocks()
                     accelerate.load_checkpoint_and_dispatch(
                         model,
-                        model_file,
+                        model_file if not is_sharded else sharded_ckpt_cached_folder,
                         device_map,
                         max_memory=max_memory,
                         offload_folder=offload_folder,
                         offload_state_dict=offload_state_dict,
                         dtype=torch_dtype,
+                        force_hook=force_hook,
+                        strict=True,
                     )
                     model._undo_temp_convert_self_to_deprecated_attention_blocks()
                 else:
@@ -1057,6 +1172,9 @@ class LegacyModelMixin(ModelMixin):
        # To prevent depedency import problem.
        from .model_loading_utils import _fetch_remapped_cls_from_config
 
+        # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
+        kwargs_copy = kwargs.copy()
+
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
@@ -1094,4 +1212,4 @@ class LegacyModelMixin(ModelMixin):
        # resolve remapping
        remapped_class = _fetch_remapped_cls_from_config(config, cls)
 
-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)