diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
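The headline additions in 0.35 are the QwenImage and SkyReels-V2 pipeline families, Flux Kontext, the new `guiders` and `modular_pipelines` packages, and parallel shard loading. As a quick orientation, here is a minimal sketch of driving one of the new pipelines; the checkpoint id and the hardware assumption (a GPU with enough memory for bf16 inference) are ours, not part of this diff:

```python
import torch
from diffusers import DiffusionPipeline

# Assumes the "Qwen/Qwen-Image" checkpoint published alongside the new
# QwenImage pipeline classes added in this release.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="a capybara reading a newspaper", num_inference_steps=30).images[0]
image.save("capybara.png")
```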
diffusers/models/controlnets/controlnet_union.py CHANGED

```diff
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 inputs.append(feat_seq.unsqueeze(1))
                 condition_list.append(condition)
             else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 controlnet_cond_fuser += condition + alpha
             else:
                 controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        elif from_multi:
+        elif from_multi or len(control_type_idx) == 1:
             down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
             mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
```
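All three hunks apply the same fix: a request with a single control type (`len(control_type_idx) == 1`) now takes the same branch as the multi-control path, so its conditioning scale is applied once, uniformly, rather than per condition. A minimal standalone sketch of the changed guard (names mirror the diff; this is not the library code):

```python
def fuse_condition(condition, alpha, scale, from_multi, control_type_idx):
    # Post-0.35: a lone control type follows the multi-control branch; its
    # scale is applied later via conditioning_scale[0], not per-condition here.
    if from_multi or len(control_type_idx) == 1:
        return condition + alpha
    return condition + alpha * scale
```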
diffusers/models/embeddings.py CHANGED

```diff
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
     return emb
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
     """
     This function generates 1D positional embeddings from a grid.
 
@@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
     emb_cos = torch.cos(out)  # (M, D/2)
 
     emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
+
     return emb
```
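A quick sketch of what the new `flip_sin_to_cos` flag does: it swaps the two halves of the embedding so the cosine terms come first (torch output type assumed):

```python
import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

pos = torch.arange(8, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid(16, pos, output_type="pt")
flipped = get_1d_sincos_pos_embed_from_grid(16, pos, output_type="pt", flip_sin_to_cos=True)

# First half is now cosine, second half sine.
assert torch.allclose(flipped[:, :8], emb[:, 8:])
assert torch.allclose(flipped[:, 8:], emb[:, :8])
```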
```diff
@@ -1176,6 +1181,7 @@ def apply_rotary_emb(
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
     use_real: bool = True,
     use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1193,8 +1199,15 @@ def apply_rotary_emb(
     """
     if use_real:
         cos, sin = freqs_cis  # [S, D]
-        cos = cos[None, None]
-        sin = sin[None, None]
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
```
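The new `sequence_dim` argument chooses where the broadcast axes are inserted, so attention code with a `[batch, seq, heads, dim]` layout can reuse the helper without transposing. The broadcasting in isolation:

```python
import torch

seq_len, head_dim = 10, 64
cos = torch.randn(seq_len, head_dim)  # [S, D], as returned by the RoPE helpers

cos_for_bhsd = cos[None, None, :, :]  # sequence_dim=2: x shaped [B, H, S, D]
cos_for_bshd = cos[None, :, None, :]  # sequence_dim=1: x shaped [B, S, H, D]
print(cos_for_bhsd.shape, cos_for_bshd.shape)  # (1, 1, 10, 64) (1, 10, 1, 64)
```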
```diff
@@ -1238,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
     return x
 
 
-class FluxPosEmbed(nn.Module):
-    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
-
-
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
@@ -2619,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
             projected_image_embeds.append(image_embed)
 
         return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+        deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxPosEmbed
+
+        return FluxPosEmbed(*args, **kwargs)
```
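`FluxPosEmbed` itself now lives in `transformer_flux.py`; the shim above keeps old imports working (with a warning) until removal in 1.0.0:

```python
# Deprecated: warns and returns the relocated class.
from diffusers.models.embeddings import FluxPosEmbed

# Preferred import path as of 0.35:
from diffusers.models.transformers.transformer_flux import FluxPosEmbed
```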
diffusers/models/model_loading_utils.py CHANGED

```diff
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
 import importlib
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -30,6 +32,7 @@ from huggingface_hub.utils import EntryNotFoundError
 
 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -38,6 +41,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
```
```diff
@@ -252,6 +256,10 @@ def load_model_dict_into_meta(
                 param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype
 
+        if is_accelerate_version(">", "1.8.1"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["clear_cache"] = False
+
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
```
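The two kwargs are only forwarded on accelerate releases new enough to accept them; presumably accelerate's `set_module_tensor_to_device` grew `non_blocking`/`clear_cache` after 1.8.1 (our inference, not stated in the diff). The gating pattern in isolation:

```python
from diffusers.utils import is_accelerate_version

set_module_kwargs = {}
if is_accelerate_version(">", "1.8.1"):
    # Assumption: these kwargs exist only in accelerate > 1.8.1.
    set_module_kwargs["non_blocking"] = True   # allow async host-to-device copies
    set_module_kwargs["clear_cache"] = False   # skip per-tensor cache clearing
```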
```diff
@@ -304,6 +312,161 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn anymore workers than you need
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
```
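These helpers implement opt-in parallel shard loading: each safetensors shard is read and materialized by a thread-pool worker, with `DEFAULT_HF_PARALLEL_LOADING_WORKERS` capping the pool size. A hedged usage sketch; the `HF_ENABLE_PARALLEL_LOADING` variable name mirrors the companion `constants.py` change and transformers' convention, and is an assumption here:

```python
import os

# Assumed opt-in switch; set it before importing diffusers, since the
# constants module reads the environment at import time.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
```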
```diff
@@ -520,3 +683,72 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
 
     return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+    mismatched_keys = []
+    if not ignore_mismatched_sizes:
+        return mismatched_keys
+    for checkpoint_key in loaded_keys:
+        model_key = checkpoint_key
+        # If the checkpoint is sharded, we may not have the key here.
+        if checkpoint_key not in state_dict:
+            continue
+
+        if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            mismatched_keys.append(
+                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+            )
+            del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
+    """
+    This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+    device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
+    which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
+    very large margin.
+    """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
+    # Keep only accelerator devices
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            p = model.get_parameter(param_name)
+        except AttributeError:
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
+        # TODO: account for TP when needed.
+        elements_per_device[device] += p.numel()
+
+    # This will kick off the caching allocator to avoid having to Malloc afterwards
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
```
diffusers/models/modeling_flax_utils.py CHANGED

```diff
@@ -369,8 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
             raise EnvironmentError(
                 f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                 "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `token` or log in with `huggingface-cli "
-                "login`."
+                "token having permission to this repo with `token` or log in with `hf auth login`."
             )
         except RevisionNotFoundError:
             raise EnvironmentError(
```