diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -14,13 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import itertools
 import json
 import os
 import re
 from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -31,6 +32,8 @@ from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
@@ -43,6 +46,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_version,
     logging,
 )
@@ -54,7 +59,9 @@ from ..utils.hub_utils import (
 from .model_loading_utils import (
     _determine_device_map,
     _fetch_index_file,
+    _fetch_index_file_legacy,
     _load_state_dict_into_model,
+    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )
@@ -92,25 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
 
 
 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
-    ... (19 removed lines; the previous implementation is not rendered in this view)
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for param in parameter.parameters():
+        last_dtype = param.dtype
+        if param.is_floating_point():
+            return param.dtype
+
+    for buffer in parameter.buffers():
+        last_dtype = buffer.dtype
+        if buffer.is_floating_point():
+            return buffer.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the first dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for tuple in gen:
+        last_tuple = tuple
+        if tuple[1].is_floating_point():
+            return tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype
 
 
 class ModelMixin(torch.nn.Module, PushToHubMixin):
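A note on the `get_parameter_dtype` rewrite above: the dtype is now resolved by scanning parameters first and buffers second, returning the first floating-point dtype found. A minimal sketch of the behavior, assuming the helper stays importable from `diffusers.models.modeling_utils` (the `Toy` module is hypothetical):

```python
import torch
from torch import nn

from diffusers.models.modeling_utils import get_parameter_dtype

# Hypothetical module mixing an integer buffer with fp16 weights.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("ids", torch.zeros(4, dtype=torch.int64))
        self.proj = nn.Linear(2, 2).half()

# Parameters are scanned before buffers, so the first floating-point
# dtype (torch.float16) wins over the int64 buffer.
print(get_parameter_dtype(Toy()))  # torch.float16
```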
@@ -128,6 +149,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     def __init__(self):
         super().__init__()
@@ -204,6 +226,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_npu_flash_attention(False)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_xla_flash_attention method
+        # gets the message
+        def fn_recursive_set_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_xla_flash_attention"):
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+
+            for child in module.children():
+                fn_recursive_set_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_flash_attention(module)
+
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+        r"""
+        Enable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(True, partition_spec)
+
+    def disable_xla_flash_attention(self):
+        r"""
+        Disable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
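The XLA helpers above simply walk the module tree and flip the flag on every child that implements `set_use_xla_flash_attention`. A usage sketch, assuming a TPU host with `torch_xla` installed and attention processors that support the kernel (the checkpoint id is illustrative):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.bfloat16
)
unet.enable_xla_flash_attention()   # route attention through the torch_xla flash kernel
# ... run inference on the XLA device ...
unet.disable_xla_flash_attention()  # restore the default attention processors
```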
@@ -311,19 +362,30 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        if hf_quantizer is not None:
+            quantization_serializable = (
+                hf_quantizer is not None
+                and isinstance(hf_quantizer, DiffusersQuantizer)
+                and hf_quantizer.is_serializable
+            )
+            if not quantization_serializable:
+                raise ValueError(
+                    f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                    " the logger on the traceback to understand the reason why the quantized model is not serializable."
+                )
+
         weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
         weights_name = _add_variant(weights_name, variant)
-        ... (3 removed lines not rendered in this view)
-        else:
-            raise ValueError(f"Invalid {weights_name} provided.")
+        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
 
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -407,6 +469,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 create_pr=create_pr,
             )
 
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -529,6 +603,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         allow_pickle = False
         if use_safetensors is None:
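This `quantization_config` kwarg is the user-facing entry point of the new `diffusers.quantizers` stack. A hedged end-to-end sketch with the bitsandbytes backend (needs a CUDA GPU with `bitsandbytes` installed; the checkpoint id is illustrative):

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
# The dequantize() helper added above restores full-precision weights
# for backends that support it.
model = model.dequantize()
```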
@@ -623,26 +698,87 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             user_agent=user_agent,
             **kwargs,
         )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # determine initial quantization config.
+        #######################################
+        pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config["quantization_config"], quantization_config
+                )
+            else:
+                config["quantization_config"] = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(
+                config["quantization_config"], pre_quantized=pre_quantized
+            )
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
+                raise NotImplementedError(
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                )
+
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+        else:
+            keep_in_fp32_modules = []
+        #######################################
 
         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
         index_file = None
         is_local = os.path.isdir(pretrained_model_name_or_path)
-        index_file = _fetch_index_file(
-            is_local=is_local,
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            subfolder=subfolder or "",
-            use_safetensors=use_safetensors,
-            cache_dir=cache_dir,
-            variant=variant,
-            force_download=force_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            user_agent=user_agent,
-            commit_hash=commit_hash,
-        )
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
         if index_file is not None and index_file.is_file():
             is_sharded = True
 
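Two guards in the block above are easy to trip in practice: an explicit `device_map` combined with a bitsandbytes config raises `NotImplementedError`, and `low_cpu_mem_usage=False` is rejected whenever a quantizer is active. A sketch, reusing the hypothetical `nf4_config` from the previous example:

```python
try:
    SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        subfolder="transformer",
        quantization_config=nf4_config,
        device_map="auto",  # not yet supported together with bitsandbytes
    )
except NotImplementedError as err:
    print(err)
```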
@@ -684,6 +820,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 revision=revision,
                 subfolder=subfolder or "",
             )
+            if hf_quantizer is not None and is_bnb_quantization_method:
+                model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+                is_sharded = False
 
         elif use_safetensors and not is_sharded:
             try:
@@ -729,13 +869,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             with accelerate.init_empty_weights():
                 model = cls.from_config(config, **unused_kwargs)
 
+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None and not is_sharded:
-                param_device = "cpu"
+                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+                # It would error out during the `validate_environment()` call above in the absence of cuda.
+                if hf_quantizer is None:
+                    param_device = "cpu"
+                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+                else:
+                    param_device = torch.device(torch.cuda.current_device())
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
+
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if hf_quantizer is not None:
+                    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
                 if len(missing_keys) > 0:
                     raise ValueError(
                         f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
@@ -750,6 +904,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     device=param_device,
                     dtype=torch_dtype,
                     model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                 )
 
             if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -765,7 +921,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
             force_hook = True
-            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            device_map = _determine_device_map(
+                model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+            )
             if device_map is None and is_sharded:
                 # we load the parameters on the cpu
                 device_map = {"": "cpu"}
@@ -843,14 +1001,25 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 "error_msgs": error_msgs,
             }
 
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(
                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
             )
-        elif torch_dtype is not None:
+        # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
+        # completely lose the effectivity of `use_keep_in_fp32_modules`.
+        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
             model = model.to(torch_dtype)
 
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        if hf_quantizer is not None:
+            # We also make sure to purge `_pre_quantization_dtype` when we serialize
+            # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+        else:
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
@@ -859,6 +1028,76 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model
 
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
+
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        if getattr(self, "is_quantized", False):
+            if dtype_present_in_args:
+                raise ValueError(
+                    "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+                    "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
+                )
+
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().to(*args, **kwargs)
+
+    # Taken from `transformers`.
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    # Taken from `transformers`.
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def _load_pretrained_model(
         cls,
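The overridden `cuda()`/`to()`/`half()`/`float()` make dtype casts on quantized models fail loudly instead of silently corrupting packed weights, while plain device moves stay allowed for 4-bit models on bitsandbytes >= 0.43.2 (8-bit models refuse to move at all). A sketch, assuming `model` is a 4-bit bnb-quantized model:

```python
import torch

model.to("cuda")  # device moves are allowed for 4-bit with bitsandbytes >= 0.43.2

try:
    model.to(torch.float16)  # any dtype cast on a quantized model is rejected
except ValueError as err:
    print(err)
```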
@@ -1041,19 +1280,63 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         859520964
         ```
         """
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+                )
 
         if exclude_embeddings:
             embedding_param_names = [
-                f"{name}.weight"
-                for name, module_type in self.named_modules()
-                if isinstance(module_type, torch.nn.Embedding)
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
 
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []