diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
diffusers/models/lora.py
CHANGED
@@ -12,19 +12,60 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
|
16
|
+
# IMPORTANT: #
|
17
|
+
###################################################################
|
18
|
+
# ----------------------------------------------------------------#
|
19
|
+
# This file is deprecated and will be removed soon #
|
20
|
+
# (as soon as PEFT will become a required dependency for LoRA) #
|
21
|
+
# ----------------------------------------------------------------#
|
22
|
+
###################################################################
|
23
|
+
|
15
24
|
from typing import Optional, Tuple, Union
|
16
25
|
|
17
26
|
import torch
|
18
27
|
import torch.nn.functional as F
|
19
28
|
from torch import nn
|
20
29
|
|
21
|
-
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
22
30
|
from ..utils import logging
|
31
|
+
from ..utils.import_utils import is_transformers_available
|
32
|
+
|
33
|
+
|
34
|
+
if is_transformers_available():
|
35
|
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
23
36
|
|
24
37
|
|
25
38
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
26
39
|
|
27
40
|
|
41
|
+
def text_encoder_attn_modules(text_encoder):
|
42
|
+
attn_modules = []
|
43
|
+
|
44
|
+
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
45
|
+
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
46
|
+
name = f"text_model.encoder.layers.{i}.self_attn"
|
47
|
+
mod = layer.self_attn
|
48
|
+
attn_modules.append((name, mod))
|
49
|
+
else:
|
50
|
+
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
|
51
|
+
|
52
|
+
return attn_modules
|
53
|
+
|
54
|
+
|
55
|
+
def text_encoder_mlp_modules(text_encoder):
|
56
|
+
mlp_modules = []
|
57
|
+
|
58
|
+
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
59
|
+
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
60
|
+
mlp_mod = layer.mlp
|
61
|
+
name = f"text_model.encoder.layers.{i}.mlp"
|
62
|
+
mlp_modules.append((name, mlp_mod))
|
63
|
+
else:
|
64
|
+
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
|
65
|
+
|
66
|
+
return mlp_modules
|
67
|
+
|
68
|
+
|
28
69
|
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
29
70
|
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
30
71
|
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
@@ -39,6 +80,95 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
|
39
80
|
mlp_module.fc2.lora_scale = lora_scale
|
40
81
|
|
41
82
|
|
83
|
+
class PatchedLoraProjection(torch.nn.Module):
|
84
|
+
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
|
85
|
+
super().__init__()
|
86
|
+
from ..models.lora import LoRALinearLayer
|
87
|
+
|
88
|
+
self.regular_linear_layer = regular_linear_layer
|
89
|
+
|
90
|
+
device = self.regular_linear_layer.weight.device
|
91
|
+
|
92
|
+
if dtype is None:
|
93
|
+
dtype = self.regular_linear_layer.weight.dtype
|
94
|
+
|
95
|
+
self.lora_linear_layer = LoRALinearLayer(
|
96
|
+
self.regular_linear_layer.in_features,
|
97
|
+
self.regular_linear_layer.out_features,
|
98
|
+
network_alpha=network_alpha,
|
99
|
+
device=device,
|
100
|
+
dtype=dtype,
|
101
|
+
rank=rank,
|
102
|
+
)
|
103
|
+
|
104
|
+
self.lora_scale = lora_scale
|
105
|
+
|
106
|
+
# overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
|
107
|
+
# when saving the whole text encoder model and when LoRA is unloaded or fused
|
108
|
+
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
109
|
+
if self.lora_linear_layer is None:
|
110
|
+
return self.regular_linear_layer.state_dict(
|
111
|
+
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
|
112
|
+
)
|
113
|
+
|
114
|
+
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
115
|
+
|
116
|
+
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
|
117
|
+
if self.lora_linear_layer is None:
|
118
|
+
return
|
119
|
+
|
120
|
+
dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
|
121
|
+
|
122
|
+
w_orig = self.regular_linear_layer.weight.data.float()
|
123
|
+
w_up = self.lora_linear_layer.up.weight.data.float()
|
124
|
+
w_down = self.lora_linear_layer.down.weight.data.float()
|
125
|
+
|
126
|
+
if self.lora_linear_layer.network_alpha is not None:
|
127
|
+
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
|
128
|
+
|
129
|
+
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
130
|
+
|
131
|
+
if safe_fusing and torch.isnan(fused_weight).any().item():
|
132
|
+
raise ValueError(
|
133
|
+
"This LoRA weight seems to be broken. "
|
134
|
+
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
135
|
+
"LoRA weights will not be fused."
|
136
|
+
)
|
137
|
+
|
138
|
+
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
|
139
|
+
|
140
|
+
# we can drop the lora layer now
|
141
|
+
self.lora_linear_layer = None
|
142
|
+
|
143
|
+
# offload the up and down matrices to CPU to not blow the memory
|
144
|
+
self.w_up = w_up.cpu()
|
145
|
+
self.w_down = w_down.cpu()
|
146
|
+
self.lora_scale = lora_scale
|
147
|
+
|
148
|
+
def _unfuse_lora(self):
|
149
|
+
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
|
150
|
+
return
|
151
|
+
|
152
|
+
fused_weight = self.regular_linear_layer.weight.data
|
153
|
+
dtype, device = fused_weight.dtype, fused_weight.device
|
154
|
+
|
155
|
+
w_up = self.w_up.to(device=device).float()
|
156
|
+
w_down = self.w_down.to(device).float()
|
157
|
+
|
158
|
+
unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
159
|
+
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
160
|
+
|
161
|
+
self.w_up = None
|
162
|
+
self.w_down = None
|
163
|
+
|
164
|
+
def forward(self, input):
|
165
|
+
if self.lora_scale is None:
|
166
|
+
self.lora_scale = 1.0
|
167
|
+
if self.lora_linear_layer is None:
|
168
|
+
return self.regular_linear_layer(input)
|
169
|
+
return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
|
170
|
+
|
171
|
+
|
42
172
|
class LoRALinearLayer(nn.Module):
|
43
173
|
r"""
|
44
174
|
A linear layer that is used with LoRA.
|
@@ -24,13 +24,17 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
|
|
24
24
|
from flax.serialization import from_bytes, to_bytes
|
25
25
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
26
26
|
from huggingface_hub import create_repo, hf_hub_download
|
27
|
-
from huggingface_hub.utils import
|
27
|
+
from huggingface_hub.utils import (
|
28
|
+
EntryNotFoundError,
|
29
|
+
RepositoryNotFoundError,
|
30
|
+
RevisionNotFoundError,
|
31
|
+
validate_hf_hub_args,
|
32
|
+
)
|
28
33
|
from requests import HTTPError
|
29
34
|
|
30
35
|
from .. import __version__, is_torch_available
|
31
36
|
from ..utils import (
|
32
37
|
CONFIG_NAME,
|
33
|
-
DIFFUSERS_CACHE,
|
34
38
|
FLAX_WEIGHTS_NAME,
|
35
39
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
36
40
|
WEIGHTS_NAME,
|
@@ -52,6 +56,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
52
56
|
|
53
57
|
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
|
54
58
|
"""
|
59
|
+
|
55
60
|
config_name = CONFIG_NAME
|
56
61
|
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
57
62
|
_flax_internal_args = ["name", "parent", "dtype"]
|
@@ -196,6 +201,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
196
201
|
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
197
202
|
|
198
203
|
@classmethod
|
204
|
+
@validate_hf_hub_args
|
199
205
|
def from_pretrained(
|
200
206
|
cls,
|
201
207
|
pretrained_model_name_or_path: Union[str, os.PathLike],
|
@@ -287,13 +293,13 @@ class FlaxModelMixin(PushToHubMixin):
|
|
287
293
|
```
|
288
294
|
"""
|
289
295
|
config = kwargs.pop("config", None)
|
290
|
-
cache_dir = kwargs.pop("cache_dir",
|
296
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
291
297
|
force_download = kwargs.pop("force_download", False)
|
292
298
|
from_pt = kwargs.pop("from_pt", False)
|
293
299
|
resume_download = kwargs.pop("resume_download", False)
|
294
300
|
proxies = kwargs.pop("proxies", None)
|
295
301
|
local_files_only = kwargs.pop("local_files_only", False)
|
296
|
-
|
302
|
+
token = kwargs.pop("token", None)
|
297
303
|
revision = kwargs.pop("revision", None)
|
298
304
|
subfolder = kwargs.pop("subfolder", None)
|
299
305
|
|
@@ -313,7 +319,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
313
319
|
resume_download=resume_download,
|
314
320
|
proxies=proxies,
|
315
321
|
local_files_only=local_files_only,
|
316
|
-
|
322
|
+
token=token,
|
317
323
|
revision=revision,
|
318
324
|
subfolder=subfolder,
|
319
325
|
**kwargs,
|
@@ -358,7 +364,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
358
364
|
proxies=proxies,
|
359
365
|
resume_download=resume_download,
|
360
366
|
local_files_only=local_files_only,
|
361
|
-
|
367
|
+
token=token,
|
362
368
|
user_agent=user_agent,
|
363
369
|
subfolder=subfolder,
|
364
370
|
revision=revision,
|
@@ -368,7 +374,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
368
374
|
raise EnvironmentError(
|
369
375
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
370
376
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
371
|
-
"token having permission to this repo with `
|
377
|
+
"token having permission to this repo with `token` or log in with `huggingface-cli "
|
372
378
|
"login`."
|
373
379
|
)
|
374
380
|
except RevisionNotFoundError:
|
@@ -436,7 +442,7 @@ class FlaxModelMixin(PushToHubMixin):
|
|
436
442
|
# make sure all arrays are stored as jnp.ndarray
|
437
443
|
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
438
444
|
# https://github.com/google/flax/issues/1261
|
439
|
-
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.
|
445
|
+
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
|
440
446
|
|
441
447
|
# flatten dicts
|
442
448
|
state = flatten_dict(state)
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from ..utils import BaseOutput
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
|
7
|
+
class AutoencoderKLOutput(BaseOutput):
|
8
|
+
"""
|
9
|
+
Output of AutoencoderKL encoding method.
|
10
|
+
|
11
|
+
Args:
|
12
|
+
latent_dist (`DiagonalGaussianDistribution`):
|
13
|
+
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
|
14
|
+
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
15
|
+
"""
|
16
|
+
|
17
|
+
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
|
@@ -18,20 +18,20 @@ import inspect
|
|
18
18
|
import itertools
|
19
19
|
import os
|
20
20
|
import re
|
21
|
+
from collections import OrderedDict
|
21
22
|
from functools import partial
|
22
23
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
23
24
|
|
24
25
|
import safetensors
|
25
26
|
import torch
|
26
27
|
from huggingface_hub import create_repo
|
27
|
-
from
|
28
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
29
|
+
from torch import Tensor, nn
|
28
30
|
|
29
31
|
from .. import __version__
|
30
32
|
from ..utils import (
|
31
33
|
CONFIG_NAME,
|
32
|
-
DIFFUSERS_CACHE,
|
33
34
|
FLAX_WEIGHTS_NAME,
|
34
|
-
HF_HUB_OFFLINE,
|
35
35
|
MIN_PEFT_VERSION,
|
36
36
|
SAFETENSORS_WEIGHTS_NAME,
|
37
37
|
WEIGHTS_NAME,
|
@@ -61,7 +61,7 @@ if is_accelerate_available():
|
|
61
61
|
from accelerate.utils.versions import is_torch_version
|
62
62
|
|
63
63
|
|
64
|
-
def get_parameter_device(parameter: torch.nn.Module):
|
64
|
+
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
65
65
|
try:
|
66
66
|
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
67
67
|
return next(parameters_and_buffers).device
|
@@ -77,7 +77,7 @@ def get_parameter_device(parameter: torch.nn.Module):
|
|
77
77
|
return first_tuple[1].device
|
78
78
|
|
79
79
|
|
80
|
-
def get_parameter_dtype(parameter: torch.nn.Module):
|
80
|
+
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
81
81
|
try:
|
82
82
|
params = tuple(parameter.parameters())
|
83
83
|
if len(params) > 0:
|
@@ -130,7 +130,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
|
130
130
|
)
|
131
131
|
|
132
132
|
|
133
|
-
def load_model_dict_into_meta(
|
133
|
+
def load_model_dict_into_meta(
|
134
|
+
model,
|
135
|
+
state_dict: OrderedDict,
|
136
|
+
device: Optional[Union[str, torch.device]] = None,
|
137
|
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
138
|
+
model_name_or_path: Optional[str] = None,
|
139
|
+
) -> List[str]:
|
134
140
|
device = device or torch.device("cpu")
|
135
141
|
dtype = dtype or torch.float32
|
136
142
|
|
@@ -156,7 +162,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
|
|
156
162
|
return unexpected_keys
|
157
163
|
|
158
164
|
|
159
|
-
def _load_state_dict_into_model(model_to_load, state_dict):
|
165
|
+
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
|
160
166
|
# Convert old format to new format if needed from a PyTorch state_dict
|
161
167
|
# copy state_dict so _load_from_state_dict can modify it
|
162
168
|
state_dict = state_dict.copy()
|
@@ -164,7 +170,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
|
|
164
170
|
|
165
171
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
166
172
|
# so we need to apply the function recursively.
|
167
|
-
def load(module: torch.nn.Module, prefix=""):
|
173
|
+
def load(module: torch.nn.Module, prefix: str = ""):
|
168
174
|
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
169
175
|
module._load_from_state_dict(*args)
|
170
176
|
|
@@ -186,6 +192,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
186
192
|
|
187
193
|
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
188
194
|
"""
|
195
|
+
|
189
196
|
config_name = CONFIG_NAME
|
190
197
|
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
191
198
|
_supports_gradient_checkpointing = False
|
@@ -220,7 +227,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
220
227
|
"""
|
221
228
|
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
222
229
|
|
223
|
-
def enable_gradient_checkpointing(self):
|
230
|
+
def enable_gradient_checkpointing(self) -> None:
|
224
231
|
"""
|
225
232
|
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
226
233
|
*checkpoint activations* in other frameworks).
|
@@ -229,7 +236,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
229
236
|
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
230
237
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
231
238
|
|
232
|
-
def disable_gradient_checkpointing(self):
|
239
|
+
def disable_gradient_checkpointing(self) -> None:
|
233
240
|
"""
|
234
241
|
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
235
242
|
*checkpoint activations* in other frameworks).
|
@@ -254,7 +261,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
254
261
|
if isinstance(module, torch.nn.Module):
|
255
262
|
fn_recursive_set_mem_eff(module)
|
256
263
|
|
257
|
-
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
264
|
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
|
258
265
|
r"""
|
259
266
|
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
260
267
|
|
@@ -290,7 +297,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
290
297
|
"""
|
291
298
|
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
292
299
|
|
293
|
-
def disable_xformers_memory_efficient_attention(self):
|
300
|
+
def disable_xformers_memory_efficient_attention(self) -> None:
|
294
301
|
r"""
|
295
302
|
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
296
303
|
"""
|
@@ -447,7 +454,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
447
454
|
self,
|
448
455
|
save_directory: Union[str, os.PathLike],
|
449
456
|
is_main_process: bool = True,
|
450
|
-
save_function: Callable = None,
|
457
|
+
save_function: Optional[Callable] = None,
|
451
458
|
safe_serialization: bool = True,
|
452
459
|
variant: Optional[str] = None,
|
453
460
|
push_to_hub: bool = False,
|
@@ -527,6 +534,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
527
534
|
)
|
528
535
|
|
529
536
|
@classmethod
|
537
|
+
@validate_hf_hub_args
|
530
538
|
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
531
539
|
r"""
|
532
540
|
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
@@ -563,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
563
571
|
local_files_only(`bool`, *optional*, defaults to `False`):
|
564
572
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
565
573
|
won't be downloaded from the Hub.
|
566
|
-
use_auth_token (`str` or *bool*, *optional*):
|
574
|
+
token (`str` or *bool*, *optional*):
|
567
575
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
568
576
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
569
577
|
revision (`str`, *optional*, defaults to `"main"`):
|
@@ -632,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
632
640
|
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
633
641
|
```
|
634
642
|
"""
|
635
|
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
643
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
636
644
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
637
645
|
force_download = kwargs.pop("force_download", False)
|
638
646
|
from_flax = kwargs.pop("from_flax", False)
|
639
647
|
resume_download = kwargs.pop("resume_download", False)
|
640
648
|
proxies = kwargs.pop("proxies", None)
|
641
649
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
642
|
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
643
|
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
650
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
651
|
+
token = kwargs.pop("token", None)
|
644
652
|
revision = kwargs.pop("revision", None)
|
645
653
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
646
654
|
subfolder = kwargs.pop("subfolder", None)
|
@@ -710,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
710
718
|
resume_download=resume_download,
|
711
719
|
proxies=proxies,
|
712
720
|
local_files_only=local_files_only,
|
713
|
-
use_auth_token=use_auth_token,
|
721
|
+
token=token,
|
714
722
|
revision=revision,
|
715
723
|
subfolder=subfolder,
|
716
724
|
device_map=device_map,
|
@@ -732,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
732
740
|
resume_download=resume_download,
|
733
741
|
proxies=proxies,
|
734
742
|
local_files_only=local_files_only,
|
735
|
-
use_auth_token=use_auth_token,
|
743
|
+
token=token,
|
736
744
|
revision=revision,
|
737
745
|
subfolder=subfolder,
|
738
746
|
user_agent=user_agent,
|
@@ -755,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
755
763
|
resume_download=resume_download,
|
756
764
|
proxies=proxies,
|
757
765
|
local_files_only=local_files_only,
|
758
|
-
use_auth_token=use_auth_token,
|
766
|
+
token=token,
|
759
767
|
revision=revision,
|
760
768
|
subfolder=subfolder,
|
761
769
|
user_agent=user_agent,
|
@@ -774,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
774
782
|
resume_download=resume_download,
|
775
783
|
proxies=proxies,
|
776
784
|
local_files_only=local_files_only,
|
777
|
-
use_auth_token=use_auth_token,
|
785
|
+
token=token,
|
778
786
|
revision=revision,
|
779
787
|
subfolder=subfolder,
|
780
788
|
user_agent=user_agent,
|
@@ -910,10 +918,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
910
918
|
def _load_pretrained_model(
|
911
919
|
cls,
|
912
920
|
model,
|
913
|
-
state_dict,
|
921
|
+
state_dict: OrderedDict,
|
914
922
|
resolved_archive_file,
|
915
|
-
pretrained_model_name_or_path,
|
916
|
-
ignore_mismatched_sizes=False,
|
923
|
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
924
|
+
ignore_mismatched_sizes: bool = False,
|
917
925
|
):
|
918
926
|
# Retrieve missing & unexpected_keys
|
919
927
|
model_state_dict = model.state_dict()
|
@@ -1011,7 +1019,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1011
1019
|
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
1012
1020
|
|
1013
1021
|
@property
|
1014
|
-
def device(self) -> device:
|
1022
|
+
def device(self) -> torch.device:
|
1015
1023
|
"""
|
1016
1024
|
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
1017
1025
|
device).
|
@@ -1063,7 +1071,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1063
1071
|
else:
|
1064
1072
|
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
1065
1073
|
|
1066
|
-
def _convert_deprecated_attention_blocks(self, state_dict):
|
1074
|
+
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
1067
1075
|
deprecated_attention_block_paths = []
|
1068
1076
|
|
1069
1077
|
def recursive_find_attn_block(name, module):
|
@@ -1107,7 +1115,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1107
1115
|
if f"{path}.proj_attn.bias" in state_dict:
|
1108
1116
|
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
1109
1117
|
|
1110
|
-
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
1118
|
+
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1111
1119
|
deprecated_attention_block_modules = []
|
1112
1120
|
|
1113
1121
|
def recursive_find_attn_block(module):
|
@@ -1134,10 +1142,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1134
1142
|
del module.to_v
|
1135
1143
|
del module.to_out
|
1136
1144
|
|
1137
|
-
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
1145
|
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1138
1146
|
deprecated_attention_block_modules = []
|
1139
1147
|
|
1140
|
-
def recursive_find_attn_block(module):
|
1148
|
+
def recursive_find_attn_block(module) -> None:
|
1141
1149
|
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1142
1150
|
deprecated_attention_block_modules.append(module)
|
1143
1151
|
|
@@ -13,14 +13,16 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
|
+
import numbers
|
16
17
|
from typing import Dict, Optional, Tuple
|
17
18
|
|
18
19
|
import torch
|
19
20
|
import torch.nn as nn
|
20
21
|
import torch.nn.functional as F
|
21
22
|
|
23
|
+
from ..utils import is_torch_version
|
22
24
|
from .activations import get_activation
|
23
|
-
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
|
25
|
+
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
24
26
|
|
25
27
|
|
26
28
|
class AdaLayerNorm(nn.Module):
|
@@ -91,7 +93,7 @@ class AdaLayerNormSingle(nn.Module):
|
|
91
93
|
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
|
92
94
|
super().__init__()
|
93
95
|
|
94
|
-
self.emb = CombinedTimestepSizeEmbeddings(
|
96
|
+
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
95
97
|
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
|
96
98
|
)
|
97
99
|
|
@@ -101,8 +103,8 @@ class AdaLayerNormSingle(nn.Module):
|
|
101
103
|
def forward(
|
102
104
|
self,
|
103
105
|
timestep: torch.Tensor,
|
104
|
-
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
105
|
-
batch_size: int = None,
|
106
|
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
107
|
+
batch_size: Optional[int] = None,
|
106
108
|
hidden_dtype: Optional[torch.dtype] = None,
|
107
109
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
108
110
|
# No modulation happening here.
|
@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
|
|
146
148
|
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
147
149
|
x = x * (1 + scale) + shift
|
148
150
|
return x
|
151
|
+
|
152
|
+
|
153
|
+
class AdaLayerNormContinuous(nn.Module):
|
154
|
+
def __init__(
|
155
|
+
self,
|
156
|
+
embedding_dim: int,
|
157
|
+
conditioning_embedding_dim: int,
|
158
|
+
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
159
|
+
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
160
|
+
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
161
|
+
# However, this is how it was implemented in the original code, and it's rather likely you should
|
162
|
+
# set `elementwise_affine` to False.
|
163
|
+
elementwise_affine=True,
|
164
|
+
eps=1e-5,
|
165
|
+
bias=True,
|
166
|
+
norm_type="layer_norm",
|
167
|
+
):
|
168
|
+
super().__init__()
|
169
|
+
self.silu = nn.SiLU()
|
170
|
+
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
171
|
+
if norm_type == "layer_norm":
|
172
|
+
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
173
|
+
elif norm_type == "rms_norm":
|
174
|
+
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
175
|
+
else:
|
176
|
+
raise ValueError(f"unknown norm_type {norm_type}")
|
177
|
+
|
178
|
+
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
179
|
+
emb = self.linear(self.silu(conditioning_embedding))
|
180
|
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
181
|
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
182
|
+
return x
|
183
|
+
|
184
|
+
|
185
|
+
if is_torch_version(">=", "2.1.0"):
|
186
|
+
LayerNorm = nn.LayerNorm
|
187
|
+
else:
|
188
|
+
# Has optional bias parameter compared to torch layer norm
|
189
|
+
# TODO: replace with torch layernorm once min required torch version >= 2.1
|
190
|
+
class LayerNorm(nn.Module):
|
191
|
+
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
|
192
|
+
super().__init__()
|
193
|
+
|
194
|
+
self.eps = eps
|
195
|
+
|
196
|
+
if isinstance(dim, numbers.Integral):
|
197
|
+
dim = (dim,)
|
198
|
+
|
199
|
+
self.dim = torch.Size(dim)
|
200
|
+
|
201
|
+
if elementwise_affine:
|
202
|
+
self.weight = nn.Parameter(torch.ones(dim))
|
203
|
+
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
|
204
|
+
else:
|
205
|
+
self.weight = None
|
206
|
+
self.bias = None
|
207
|
+
|
208
|
+
def forward(self, input):
|
209
|
+
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
210
|
+
|
211
|
+
|
212
|
+
class RMSNorm(nn.Module):
|
213
|
+
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
214
|
+
super().__init__()
|
215
|
+
|
216
|
+
self.eps = eps
|
217
|
+
|
218
|
+
if isinstance(dim, numbers.Integral):
|
219
|
+
dim = (dim,)
|
220
|
+
|
221
|
+
self.dim = torch.Size(dim)
|
222
|
+
|
223
|
+
if elementwise_affine:
|
224
|
+
self.weight = nn.Parameter(torch.ones(dim))
|
225
|
+
else:
|
226
|
+
self.weight = None
|
227
|
+
|
228
|
+
def forward(self, hidden_states):
|
229
|
+
input_dtype = hidden_states.dtype
|
230
|
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
231
|
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
232
|
+
|
233
|
+
if self.weight is not None:
|
234
|
+
# convert into half-precision if necessary
|
235
|
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
236
|
+
hidden_states = hidden_states.to(self.weight.dtype)
|
237
|
+
hidden_states = hidden_states * self.weight
|
238
|
+
else:
|
239
|
+
hidden_states = hidden_states.to(input_dtype)
|
240
|
+
|
241
|
+
return hidden_states
|
242
|
+
|
243
|
+
|
244
|
+
class GlobalResponseNorm(nn.Module):
|
245
|
+
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
246
|
+
def __init__(self, dim):
|
247
|
+
super().__init__()
|
248
|
+
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
249
|
+
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
250
|
+
|
251
|
+
def forward(self, x):
|
252
|
+
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
253
|
+
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
|
254
|
+
return self.gamma * (x * nx) + self.beta + x
|