diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/model_loading_utils.py (new file)

@@ -0,0 +1,177 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import safetensors
+import torch
+
+from ..utils import (
+    SAFETENSORS_FILE_EXTENSION,
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+_CLASS_REMAPPING_DICT = {
+    "Transformer2DModel": {
+        "ada_norm_zero": "DiTTransformer2DModel",
+        "ada_norm_single": "PixArtTransformer2DModel",
+    }
+}
+
+
+if is_accelerate_available():
+    from accelerate import infer_auto_device_map
+    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+
+
+# Adapted from `transformers` (see modeling_utils.py)
+def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+    if isinstance(device_map, str):
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                dtype=torch_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            max_memory = get_max_memory(max_memory)
+
+        device_map_kwargs["max_memory"] = max_memory
+        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+    return device_map
+
+
+def _fetch_remapped_cls_from_config(config, old_class):
+    previous_class_name = old_class.__name__
+    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
+
+    # Details:
+    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
+    if remapped_class_name:
+        # load diffusers library to import compatible and original scheduler
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+        remapped_class = getattr(diffusers_library, remapped_class_name)
+        logger.info(
+            f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
+            f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
+            " DOESN'T affect the final results."
+        )
+        return remapped_class
+    else:
+        return old_class
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+    """
+    Reads a checkpoint file, returning properly formatted errors if they arise.
+    """
+    try:
+        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if file_extension == SAFETENSORS_FILE_EXTENSION:
+            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+        else:
+            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
+            return torch.load(
+                checkpoint_file,
+                map_location="cpu",
+                **weights_only_kwarg,
+            )
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+            )
+
+
+def load_model_dict_into_meta(
+    model,
+    state_dict: OrderedDict,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    model_name_or_path: Optional[str] = None,
+) -> List[str]:
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )
+
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
+def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+    # Convert old format to new format if needed from a PyTorch state_dict
+    # copy state_dict so _load_from_state_dict can modify it
+    state_dict = state_dict.copy()
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: torch.nn.Module, prefix: str = ""):
+        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + ".")
+
+    load(model_to_load)
+
+    return error_msgs
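The remapping table above is what later lets a legacy `Transformer2DModel` checkpoint load as a specialized class: the legacy class name plus the config's `norm_type` select the replacement class name, and anything else falls back to the legacy class. A minimal standalone sketch of that rule (plain Python, no diffusers import; `resolve_class_name` is an illustrative helper, not part of the library):

```python
# Sketch of the _CLASS_REMAPPING_DICT / _fetch_remapped_cls_from_config lookup, reduced to names.
CLASS_REMAPPING = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}


def resolve_class_name(legacy_name: str, config: dict) -> str:
    # No remapping entry (or an unmapped norm_type) keeps the legacy class, as in the diff.
    remapped = CLASS_REMAPPING.get(legacy_name, {}).get(config.get("norm_type"))
    return remapped or legacy_name


print(resolve_class_name("Transformer2DModel", {"norm_type": "ada_norm_zero"}))    # DiTTransformer2DModel
print(resolve_class_name("Transformer2DModel", {"norm_type": "ada_norm_single"}))  # PixArtTransformer2DModel
print(resolve_class_name("Transformer2DModel", {"norm_type": "layer_norm"}))       # Transformer2DModel
```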
diffusers/models/modeling_flax_pytorch_utils.py

@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""PyTorch - Flax general utilities."""
+
 import re
 
 import jax.numpy as jnp
diffusers/models/modeling_flax_utils.py

@@ -245,9 +245,9 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download
-
-
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +296,7 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download",
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
diffusers/models/modeling_outputs.py

@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
     """
 
     latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+    """
+    The output of [`Transformer2DModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+            distributions for the unnoised latent pixels.
+    """
+
+    sample: "torch.Tensor"  # noqa: F821
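`Transformer2DModelOutput` now lives in `diffusers.models.modeling_outputs`, so it can be imported independently of the transformer implementations. A small usage sketch, assuming diffusers 0.28.1 and torch are installed:

```python
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput

# The output is a dataclass-style wrapper; the hidden-state tensor is exposed as `.sample`.
out = Transformer2DModelOutput(sample=torch.randn(1, 4, 64, 64))
print(out.sample.shape)  # torch.Size([1, 4, 64, 64])
```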
diffusers/models/modeling_pytorch_flax_utils.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""PyTorch - Flax general utilities."""
 
 from pickle import UnpicklingError
 
diffusers/models/modeling_utils.py

@@ -20,6 +20,7 @@ import os
 import re
 from collections import OrderedDict
 from functools import partial
+from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
@@ -32,7 +33,6 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
-    SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
@@ -42,7 +42,17 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import
+from ..utils.hub_utils import (
+    PushToHubMixin,
+    load_or_create_model_card,
+    populate_model_card,
+)
+from .model_loading_utils import (
+    _determine_device_map,
+    _load_state_dict_into_model,
+    load_model_dict_into_meta,
+    load_state_dict,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -56,8 +66,6 @@ else:
 
 if is_accelerate_available():
     import accelerate
-    from accelerate.utils import set_module_tensor_to_device
-    from accelerate.utils.versions import is_torch_version
 
 
 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -98,89 +106,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
         return first_tuple[1].dtype
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-    """
-    Reads a checkpoint file, returning properly formatted errors if they arise.
-    """
-    try:
-        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
-        if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
-        else:
-            return torch.load(checkpoint_file, map_location="cpu")
-    except Exception as e:
-        try:
-            with open(checkpoint_file) as f:
-                if f.read().startswith("version"):
-                    raise OSError(
-                        "You seem to have cloned a repository without having git-lfs installed. Please install "
-                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                        "you cloned."
-                    )
-                else:
-                    raise ValueError(
-                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
-                        "model. Make sure you have saved the model properly."
-                    ) from e
-        except (UnicodeDecodeError, ValueError):
-            raise OSError(
-                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
-            )
-
-
-def load_model_dict_into_meta(
-    model,
-    state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
-    dtype: Optional[Union[str, torch.dtype]] = None,
-    model_name_or_path: Optional[str] = None,
-) -> List[str]:
-    device = device or torch.device("cpu")
-    dtype = dtype or torch.float32
-
-    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
-    unexpected_keys = []
-    empty_state_dict = model.state_dict()
-    for param_name, param in state_dict.items():
-        if param_name not in empty_state_dict:
-            unexpected_keys.append(param_name)
-            continue
-
-        if empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
-
-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
-        else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
-    return unexpected_keys
-
-
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
-    # Convert old format to new format if needed from a PyTorch state_dict
-    # copy state_dict so _load_from_state_dict can modify it
-    state_dict = state_dict.copy()
-    error_msgs = []
-
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-    # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
-        module._load_from_state_dict(*args)
-
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, prefix + name + ".")
-
-    load(model_to_load)
-
-    return error_msgs
-
-
 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
@@ -195,6 +120,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _no_split_modules = None
 
     def __init__(self):
         super().__init__()
@@ -241,6 +167,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_npu_flash_attention(self, valid: bool) -> None:
+        r"""
+        Set the switch for the npu flash attention.
+        """
+
+        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_npu_flash_attention"):
+                module.set_use_npu_flash_attention(valid)
+
+            for child in module.children():
+                fn_recursive_set_npu_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_npu_flash_attention(module)
+
+    def enable_npu_flash_attention(self) -> None:
+        r"""
+        Enable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(True)
+
+    def disable_npu_flash_attention(self) -> None:
+        r"""
+        disable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
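`enable_npu_flash_attention()` / `disable_npu_flash_attention()` just flip one switch recursively over submodules. A standalone sketch of that recursion pattern (`Block` and `toggle_recursively` are illustrative stand-ins, not diffusers classes):

```python
import torch


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.use_npu_flash_attention = False

    def set_use_npu_flash_attention(self, valid: bool) -> None:
        self.use_npu_flash_attention = valid


def toggle_recursively(root: torch.nn.Module, valid: bool) -> None:
    # Mirrors fn_recursive_set_npu_flash_attention: call the setter where it exists, then recurse.
    for child in root.children():
        if hasattr(child, "set_use_npu_flash_attention"):
            child.set_use_npu_flash_attention(valid)
        toggle_recursively(child, valid)


model = torch.nn.Sequential(Block(), torch.nn.Sequential(Block(), Block()))
toggle_recursively(model, True)
print([m.use_npu_flash_attention for m in model.modules() if isinstance(m, Block)])  # [True, True, True]
```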
@@ -367,18 +323,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         if safe_serialization:
             safetensors.torch.save_file(
-                state_dict,
+                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
             )
         else:
-            torch.save(state_dict,
+            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
 
-        logger.info(f"Model weights saved in {
+        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
             model_card = load_or_create_model_card(repo_id, token=token)
             model_card = populate_model_card(model_card)
-            model_card.save(
+            model_card.save(Path(save_directory, "README.md").as_posix())
 
             self._upload_folder(
                 save_directory,
@@ -415,9 +371,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download
-
-
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -499,7 +455,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         from_flax = kwargs.pop("from_flax", False)
-        resume_download = kwargs.pop("resume_download",
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", None)
@@ -554,6 +510,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )
 
+        # change device_map into a map if we passed an int, a str or a torch.device
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+            elif not low_cpu_mem_usage:
+                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+        if low_cpu_mem_usage:
+            if device_map is not None and not is_torch_version(">=", "1.10"):
+                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
+                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
+
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
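The block added above means `device_map` may now be a concrete device (string, int, or `torch.device`) as well as one of the accelerate placement strategies. A standalone sketch of just that normalization step (`normalize_device_map` is an illustrative name; the real logic runs inside `from_pretrained`):

```python
import torch

ACCELERATE_STRATEGIES = ["auto", "balanced", "balanced_low_0", "sequential"]


def normalize_device_map(device_map):
    # Mirrors the added block: concrete devices collapse to a single-entry map, strategies pass through.
    if isinstance(device_map, torch.device):
        return {"": device_map}
    if isinstance(device_map, str) and device_map not in ACCELERATE_STRATEGIES:
        return {"": torch.device(device_map)}  # torch.device() raises for unknown device strings
    if isinstance(device_map, int):
        if device_map < 0:
            raise ValueError("device_map cannot be a negative int; pass 'cpu' instead")
        return {"": device_map}
    return device_map  # None or an accelerate strategy string is left untouched


print(normalize_device_map("cpu"))                # {'': device(type='cpu')}
print(normalize_device_map(torch.device("cpu")))  # {'': device(type='cpu')}
print(normalize_device_map(0))                    # {'': 0}
print(normalize_device_map("auto"))               # auto
```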
|
@@ -576,10 +562,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             token=token,
             revision=revision,
             subfolder=subfolder,
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
             user_agent=user_agent,
             **kwargs,
         )
@@ -684,6 +666,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
+                device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
                try:
                    accelerate.load_checkpoint_and_dispatch(
                        model,
@@ -693,6 +676,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                        offload_folder=offload_folder,
                        offload_state_dict=offload_state_dict,
                        dtype=torch_dtype,
+                        force_hooks=True,
+                        strict=True,
                    )
                except AttributeError as e:
                    # When using accelerate loading, we do not have the ability to load the state
@@ -873,6 +858,45 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
 
+    @classmethod
+    def _get_signature_keys(cls, obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        expected_modules = set(required_parameters.keys()) - {"self"}
+
+        return expected_modules, optional_parameters
+
+    # Adapted from `transformers` modeling_utils.py
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, ModelMixin):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
     @property
     def device(self) -> torch.device:
         """
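`_get_no_split_modules` feeds accelerate's device-map inference the block classes that must stay on one device: it walks the model and unions every `_no_split_modules` declaration it finds, erroring if a `ModelMixin` subclass never declared one. A standalone sketch of that traversal (`DummyMixin`, `Inner`, `Outer`, and `collect_no_split_modules` are illustrative stand-ins):

```python
import torch


class DummyMixin(torch.nn.Module):
    _no_split_modules = None  # subclasses must override this to opt into device_map support


class Inner(DummyMixin):
    _no_split_modules = ["BasicTransformerBlock"]


class Outer(DummyMixin):
    _no_split_modules = ["ResnetBlock2D"]

    def __init__(self):
        super().__init__()
        self.inner = Inner()


def collect_no_split_modules(root: torch.nn.Module) -> list:
    found, stack = set(), [root]
    while stack:
        module = stack.pop()
        # Only descend into modules that are not themselves marked as unsplittable.
        if module.__class__.__name__ not in found:
            if isinstance(module, DummyMixin):
                if module._no_split_modules is None:
                    raise ValueError(f"{module.__class__.__name__} does not declare _no_split_modules")
                found |= set(module._no_split_modules)
            stack += list(module.children())
    return sorted(found)


print(collect_no_split_modules(Outer()))  # ['BasicTransformerBlock', 'ResnetBlock2D']
```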
@@ -1019,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             del module.key
             del module.value
             del module.proj_attn
+
+
+class LegacyModelMixin(ModelMixin):
+    r"""
+    A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+    pipeline-specific classes (like `DiTTransformer2DModel`).
+    """
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        # To prevent depedency import problem.
+        from .model_loading_utils import _fetch_remapped_cls_from_config
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # load config
+        config, _, _ = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            **kwargs,
+        )
+        # resolve remapping
+        remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
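Put together, `LegacyModelMixin.from_pretrained` only reads the config, resolves the replacement class via `_fetch_remapped_cls_from_config`, and delegates the actual load to that class. A self-contained sketch of the dispatch pattern, with illustrative stand-in classes (`LegacyTransformer`, `DiTLike`, `PixArtLike`) rather than the real diffusers models:

```python
REMAP = {"ada_norm_zero": "DiTLike", "ada_norm_single": "PixArtLike"}


class Base:
    @classmethod
    def from_pretrained(cls, config: dict):
        print(f"loading weights as {cls.__name__}")
        return cls()


class DiTLike(Base):
    pass


class PixArtLike(Base):
    pass


REGISTRY = {"DiTLike": DiTLike, "PixArtLike": PixArtLike}


class LegacyTransformer(Base):
    @classmethod
    def from_pretrained(cls, config: dict):
        target = REGISTRY.get(REMAP.get(config.get("norm_type")))
        if target is None:
            return super().from_pretrained(config)  # no remapping: load with the legacy class
        return target.from_pretrained(config)       # delegate the real load to the resolved class


LegacyTransformer.from_pretrained({"norm_type": "ada_norm_zero"})  # prints: loading weights as DiTLike
LegacyTransformer.from_pretrained({"norm_type": "layer_norm"})     # prints: loading weights as LegacyTransformer
```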
diffusers/models/normalization.py

@@ -176,7 +176,8 @@ class AdaLayerNormContinuous(nn.Module):
             raise ValueError(f"unknown norm_type {norm_type}")
 
     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
-
+        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
         scale, shift = torch.chunk(emb, 2, dim=1)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x
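The functional change here is the `.to(x.dtype)` cast: when the conditioning embedding has been upcast to float32 (as the HunyuanDiT path does) while the block's weights run in half precision, the projection would otherwise see mismatched dtypes. A standalone sketch of the pattern (shapes and `bfloat16` are illustrative):

```python
import torch

dtype = torch.bfloat16                                 # stand-in for a half-precision pipeline
x = torch.randn(2, 4, 8, dtype=dtype)                  # hidden states in low precision
conditioning = torch.randn(2, 8, dtype=torch.float32)  # embedding upcast to float32 elsewhere

silu = torch.nn.SiLU()
linear = torch.nn.Linear(8, 16, dtype=dtype)

emb = linear(silu(conditioning).to(x.dtype))           # the cast added in the diff keeps dtypes aligned
scale, shift = torch.chunk(emb, 2, dim=1)
print(emb.dtype, scale.shape, shift.shape)             # torch.bfloat16 torch.Size([2, 8]) torch.Size([2, 8])
```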