diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
|
|
29
29
|
"""Returns the positional encoding (same as Tensor2Tensor).
|
30
30
|
|
31
31
|
Args:
|
32
|
-
timesteps
|
33
|
-
|
34
|
-
embedding_dim:
|
35
|
-
|
36
|
-
|
32
|
+
timesteps (`jnp.ndarray` of shape `(N,)`):
|
33
|
+
A 1-D array of N indices, one per batch element. These may be fractional.
|
34
|
+
embedding_dim (`int`):
|
35
|
+
The number of output channels.
|
36
|
+
freq_shift (`float`, *optional*, defaults to `1`):
|
37
|
+
Shift applied to the frequency scaling of the embeddings.
|
38
|
+
min_timescale (`float`, *optional*, defaults to `1`):
|
39
|
+
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
|
40
|
+
max_timescale (`float`, *optional*, defaults to `1.0e4`):
|
41
|
+
The largest time unit used in the sinusoidal calculation.
|
42
|
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
43
|
+
Whether to flip the order of sinusoidal components to cosine first.
|
44
|
+
scale (`float`, *optional*, defaults to `1.0`):
|
45
|
+
A scaling factor applied to the positional embeddings.
|
46
|
+
|
37
47
|
Returns:
|
38
48
|
a Tensor of timing signals [N, num_channels]
|
39
49
|
"""
|
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
|
|
61
71
|
|
62
72
|
Args:
|
63
73
|
time_embed_dim (`int`, *optional*, defaults to `32`):
|
64
|
-
|
65
|
-
dtype (
|
66
|
-
|
74
|
+
Time step embedding dimension.
|
75
|
+
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
|
76
|
+
The data type for the embedding parameters.
|
67
77
|
"""
|
68
78
|
|
69
79
|
time_embed_dim: int = 32
|
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
|
|
83
93
|
|
84
94
|
Args:
|
85
95
|
dim (`int`, *optional*, defaults to `32`):
|
86
|
-
|
96
|
+
Time step embedding dimension.
|
97
|
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
98
|
+
Whether to flip the sinusoidal function from sine to cosine.
|
99
|
+
freq_shift (`float`, *optional*, defaults to `1`):
|
100
|
+
Frequency shift applied to the sinusoidal embeddings.
|
87
101
|
"""
|
88
102
|
|
89
103
|
dim: int = 32
|
@@ -17,6 +17,7 @@
|
|
17
17
|
import importlib
|
18
18
|
import inspect
|
19
19
|
import os
|
20
|
+
from array import array
|
20
21
|
from collections import OrderedDict
|
21
22
|
from pathlib import Path
|
22
23
|
from typing import List, Optional, Union
|
@@ -26,12 +27,16 @@ import torch
|
|
26
27
|
from huggingface_hub.utils import EntryNotFoundError
|
27
28
|
|
28
29
|
from ..utils import (
|
30
|
+
GGUF_FILE_EXTENSION,
|
29
31
|
SAFE_WEIGHTS_INDEX_NAME,
|
30
32
|
SAFETENSORS_FILE_EXTENSION,
|
31
33
|
WEIGHTS_INDEX_NAME,
|
32
34
|
_add_variant,
|
33
35
|
_get_model_file,
|
36
|
+
deprecate,
|
34
37
|
is_accelerate_available,
|
38
|
+
is_gguf_available,
|
39
|
+
is_torch_available,
|
35
40
|
is_torch_version,
|
36
41
|
logging,
|
37
42
|
)
|
@@ -53,11 +58,36 @@ if is_accelerate_available():
|
|
53
58
|
|
54
59
|
|
55
60
|
# Adapted from `transformers` (see modeling_utils.py)
|
56
|
-
def _determine_device_map(
|
61
|
+
def _determine_device_map(
|
62
|
+
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
|
63
|
+
):
|
57
64
|
if isinstance(device_map, str):
|
65
|
+
special_dtypes = {}
|
66
|
+
if hf_quantizer is not None:
|
67
|
+
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
|
68
|
+
special_dtypes.update(
|
69
|
+
{
|
70
|
+
name: torch.float32
|
71
|
+
for name, _ in model.named_parameters()
|
72
|
+
if any(m in name for m in keep_in_fp32_modules)
|
73
|
+
}
|
74
|
+
)
|
75
|
+
|
76
|
+
target_dtype = torch_dtype
|
77
|
+
if hf_quantizer is not None:
|
78
|
+
target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
|
79
|
+
|
58
80
|
no_split_modules = model._get_no_split_modules(device_map)
|
59
81
|
device_map_kwargs = {"no_split_module_classes": no_split_modules}
|
60
82
|
|
83
|
+
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
84
|
+
device_map_kwargs["special_dtypes"] = special_dtypes
|
85
|
+
elif len(special_dtypes) > 0:
|
86
|
+
logger.warning(
|
87
|
+
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
88
|
+
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
|
89
|
+
)
|
90
|
+
|
61
91
|
if device_map != "sequential":
|
62
92
|
max_memory = get_balanced_memory(
|
63
93
|
model,
|
@@ -69,8 +99,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
|
|
69
99
|
else:
|
70
100
|
max_memory = get_max_memory(max_memory)
|
71
101
|
|
102
|
+
if hf_quantizer is not None:
|
103
|
+
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
104
|
+
|
72
105
|
device_map_kwargs["max_memory"] = max_memory
|
73
|
-
device_map = infer_auto_device_map(model, dtype=
|
106
|
+
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
107
|
+
|
108
|
+
if hf_quantizer is not None:
|
109
|
+
hf_quantizer.validate_environment(device_map=device_map)
|
74
110
|
|
75
111
|
return device_map
|
76
112
|
|
@@ -99,10 +135,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
|
99
135
|
"""
|
100
136
|
Reads a checkpoint file, returning properly formatted errors if they arise.
|
101
137
|
"""
|
138
|
+
# TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
|
139
|
+
# when refactoring the _merge_sharded_checkpoints() method later.
|
140
|
+
if isinstance(checkpoint_file, dict):
|
141
|
+
return checkpoint_file
|
102
142
|
try:
|
103
143
|
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
104
144
|
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
105
145
|
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
146
|
+
elif file_extension == GGUF_FILE_EXTENSION:
|
147
|
+
return load_gguf_checkpoint(checkpoint_file)
|
106
148
|
else:
|
107
149
|
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
108
150
|
return torch.load(
|
@@ -136,29 +178,69 @@ def load_model_dict_into_meta(
|
|
136
178
|
device: Optional[Union[str, torch.device]] = None,
|
137
179
|
dtype: Optional[Union[str, torch.dtype]] = None,
|
138
180
|
model_name_or_path: Optional[str] = None,
|
181
|
+
hf_quantizer=None,
|
182
|
+
keep_in_fp32_modules=None,
|
139
183
|
) -> List[str]:
|
140
|
-
device
|
184
|
+
if device is not None and not isinstance(device, (str, torch.device)):
|
185
|
+
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
|
186
|
+
if hf_quantizer is None:
|
187
|
+
device = device or torch.device("cpu")
|
141
188
|
dtype = dtype or torch.float32
|
189
|
+
is_quantized = hf_quantizer is not None
|
142
190
|
|
143
191
|
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
144
|
-
|
145
|
-
unexpected_keys = []
|
146
192
|
empty_state_dict = model.state_dict()
|
193
|
+
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
|
194
|
+
|
147
195
|
for param_name, param in state_dict.items():
|
148
196
|
if param_name not in empty_state_dict:
|
149
|
-
unexpected_keys.append(param_name)
|
150
197
|
continue
|
151
198
|
|
199
|
+
set_module_kwargs = {}
|
200
|
+
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
|
201
|
+
# in int/uint/bool and not cast them.
|
202
|
+
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
|
203
|
+
if torch.is_floating_point(param):
|
204
|
+
if (
|
205
|
+
keep_in_fp32_modules is not None
|
206
|
+
and any(
|
207
|
+
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
208
|
+
)
|
209
|
+
and dtype == torch.float16
|
210
|
+
):
|
211
|
+
param = param.to(torch.float32)
|
212
|
+
if accepts_dtype:
|
213
|
+
set_module_kwargs["dtype"] = torch.float32
|
214
|
+
else:
|
215
|
+
param = param.to(dtype)
|
216
|
+
if accepts_dtype:
|
217
|
+
set_module_kwargs["dtype"] = dtype
|
218
|
+
|
219
|
+
# bnb params are flattened.
|
220
|
+
# gguf quants have a different shape based on the type of quantization applied
|
152
221
|
if empty_state_dict[param_name].shape != param.shape:
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
222
|
+
if (
|
223
|
+
is_quantized
|
224
|
+
and hf_quantizer.pre_quantized
|
225
|
+
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
226
|
+
):
|
227
|
+
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
|
228
|
+
else:
|
229
|
+
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
230
|
+
raise ValueError(
|
231
|
+
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
232
|
+
)
|
233
|
+
|
234
|
+
if is_quantized and (
|
235
|
+
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
236
|
+
):
|
237
|
+
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
|
160
238
|
else:
|
161
|
-
|
239
|
+
if accepts_dtype:
|
240
|
+
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
|
241
|
+
else:
|
242
|
+
set_module_tensor_to_device(model, param_name, device, value=param)
|
243
|
+
|
162
244
|
return unexpected_keys
|
163
245
|
|
164
246
|
|
@@ -228,3 +310,171 @@ def _fetch_index_file(
|
|
228
310
|
index_file = None
|
229
311
|
|
230
312
|
return index_file
|
313
|
+
|
314
|
+
|
315
|
+
# Adapted from
|
316
|
+
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
|
317
|
+
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
|
318
|
+
weight_map = sharded_metadata.get("weight_map", None)
|
319
|
+
if weight_map is None:
|
320
|
+
raise KeyError("'weight_map' key not found in the shard index file.")
|
321
|
+
|
322
|
+
# Collect all unique safetensors files from weight_map
|
323
|
+
files_to_load = set(weight_map.values())
|
324
|
+
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
|
325
|
+
merged_state_dict = {}
|
326
|
+
|
327
|
+
# Load tensors from each unique file
|
328
|
+
for file_name in files_to_load:
|
329
|
+
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
|
330
|
+
if not os.path.exists(part_file_path):
|
331
|
+
raise FileNotFoundError(f"Part file {file_name} not found.")
|
332
|
+
|
333
|
+
if is_safetensors:
|
334
|
+
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
|
335
|
+
for tensor_key in f.keys():
|
336
|
+
if tensor_key in weight_map:
|
337
|
+
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
|
338
|
+
else:
|
339
|
+
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
|
340
|
+
|
341
|
+
return merged_state_dict
|
342
|
+
|
343
|
+
|
344
|
+
def _fetch_index_file_legacy(
|
345
|
+
is_local,
|
346
|
+
pretrained_model_name_or_path,
|
347
|
+
subfolder,
|
348
|
+
use_safetensors,
|
349
|
+
cache_dir,
|
350
|
+
variant,
|
351
|
+
force_download,
|
352
|
+
proxies,
|
353
|
+
local_files_only,
|
354
|
+
token,
|
355
|
+
revision,
|
356
|
+
user_agent,
|
357
|
+
commit_hash,
|
358
|
+
):
|
359
|
+
if is_local:
|
360
|
+
index_file = Path(
|
361
|
+
pretrained_model_name_or_path,
|
362
|
+
subfolder or "",
|
363
|
+
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
364
|
+
).as_posix()
|
365
|
+
splits = index_file.split(".")
|
366
|
+
split_index = -3 if ".cache" in index_file else -2
|
367
|
+
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
368
|
+
index_file = ".".join(splits)
|
369
|
+
if os.path.exists(index_file):
|
370
|
+
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
371
|
+
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
372
|
+
index_file = Path(index_file)
|
373
|
+
else:
|
374
|
+
index_file = None
|
375
|
+
else:
|
376
|
+
if variant is not None:
|
377
|
+
index_file_in_repo = Path(
|
378
|
+
subfolder or "",
|
379
|
+
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
380
|
+
).as_posix()
|
381
|
+
splits = index_file_in_repo.split(".")
|
382
|
+
split_index = -2
|
383
|
+
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
384
|
+
index_file_in_repo = ".".join(splits)
|
385
|
+
try:
|
386
|
+
index_file = _get_model_file(
|
387
|
+
pretrained_model_name_or_path,
|
388
|
+
weights_name=index_file_in_repo,
|
389
|
+
cache_dir=cache_dir,
|
390
|
+
force_download=force_download,
|
391
|
+
proxies=proxies,
|
392
|
+
local_files_only=local_files_only,
|
393
|
+
token=token,
|
394
|
+
revision=revision,
|
395
|
+
subfolder=None,
|
396
|
+
user_agent=user_agent,
|
397
|
+
commit_hash=commit_hash,
|
398
|
+
)
|
399
|
+
index_file = Path(index_file)
|
400
|
+
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
401
|
+
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
402
|
+
except (EntryNotFoundError, EnvironmentError):
|
403
|
+
index_file = None
|
404
|
+
|
405
|
+
return index_file
|
406
|
+
|
407
|
+
|
408
|
+
def _gguf_parse_value(_value, data_type):
|
409
|
+
if not isinstance(data_type, list):
|
410
|
+
data_type = [data_type]
|
411
|
+
if len(data_type) == 1:
|
412
|
+
data_type = data_type[0]
|
413
|
+
array_data_type = None
|
414
|
+
else:
|
415
|
+
if data_type[0] != 9:
|
416
|
+
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
|
417
|
+
data_type, array_data_type = data_type
|
418
|
+
|
419
|
+
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
|
420
|
+
_value = int(_value[0])
|
421
|
+
elif data_type in [6, 12]:
|
422
|
+
_value = float(_value[0])
|
423
|
+
elif data_type in [7]:
|
424
|
+
_value = bool(_value[0])
|
425
|
+
elif data_type in [8]:
|
426
|
+
_value = array("B", list(_value)).tobytes().decode()
|
427
|
+
elif data_type in [9]:
|
428
|
+
_value = _gguf_parse_value(_value, array_data_type)
|
429
|
+
return _value
|
430
|
+
|
431
|
+
|
432
|
+
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
433
|
+
"""
|
434
|
+
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
|
435
|
+
attributes.
|
436
|
+
|
437
|
+
Args:
|
438
|
+
gguf_checkpoint_path (`str`):
|
439
|
+
The path the to GGUF file to load
|
440
|
+
return_tensors (`bool`, defaults to `True`):
|
441
|
+
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
|
442
|
+
metadata in memory.
|
443
|
+
"""
|
444
|
+
|
445
|
+
if is_gguf_available() and is_torch_available():
|
446
|
+
import gguf
|
447
|
+
from gguf import GGUFReader
|
448
|
+
|
449
|
+
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
|
450
|
+
else:
|
451
|
+
logger.error(
|
452
|
+
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
|
453
|
+
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
|
454
|
+
)
|
455
|
+
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
|
456
|
+
|
457
|
+
reader = GGUFReader(gguf_checkpoint_path)
|
458
|
+
|
459
|
+
parsed_parameters = {}
|
460
|
+
for tensor in reader.tensors:
|
461
|
+
name = tensor.name
|
462
|
+
quant_type = tensor.tensor_type
|
463
|
+
|
464
|
+
# if the tensor is a torch supported dtype do not use GGUFParameter
|
465
|
+
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
|
466
|
+
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
|
467
|
+
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
|
468
|
+
raise ValueError(
|
469
|
+
(
|
470
|
+
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
|
471
|
+
"\n\nCurrently the following quantization types are supported: \n\n"
|
472
|
+
f"{_supported_quants_str}"
|
473
|
+
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
|
474
|
+
)
|
475
|
+
)
|
476
|
+
|
477
|
+
weights = torch.from_numpy(tensor.data.copy())
|
478
|
+
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
|
479
|
+
|
480
|
+
return parsed_parameters
|
@@ -530,7 +530,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
530
530
|
|
531
531
|
if push_to_hub:
|
532
532
|
commit_message = kwargs.pop("commit_message", None)
|
533
|
-
private = kwargs.pop("private",
|
533
|
+
private = kwargs.pop("private", None)
|
534
534
|
create_pr = kwargs.pop("create_pr", False)
|
535
535
|
token = kwargs.pop("token", None)
|
536
536
|
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|