diffusers 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
- diffusers/__init__.py +11 -1
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +12 -8
- diffusers/dependency_versions_table.py +2 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +286 -46
- diffusers/loaders/ip_adapter.py +11 -9
- diffusers/loaders/lora.py +198 -60
- diffusers/loaders/single_file.py +24 -18
- diffusers/loaders/textual_inversion.py +10 -14
- diffusers/loaders/unet.py +130 -37
- diffusers/models/__init__.py +18 -12
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +137 -16
- diffusers/models/attention_processor.py +133 -46
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
- diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
- diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/modeling_flax_utils.py +12 -7
- diffusers/models/modeling_utils.py +10 -10
- diffusers/models/normalization.py +108 -2
- diffusers/models/resnet.py +15 -699
- diffusers/models/transformer_2d.py +2 -2
- diffusers/models/unet_2d_condition.py +37 -0
- diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vq_model.py +9 -2
- diffusers/pipelines/__init__.py +81 -73
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/kandinsky3/__init__.py +4 -4
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
- diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/pipeline_flax_utils.py +7 -6
- diffusers/pipelines/pipeline_utils.py +30 -29
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +1 -72
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
- diffusers/schedulers/scheduling_ddpm.py +46 -0
- diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
- diffusers/schedulers/scheduling_deis_multistep.py +13 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
- diffusers/schedulers/scheduling_euler_discrete.py +62 -3
- diffusers/schedulers/scheduling_heun_discrete.py +2 -0
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -0
- diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +0 -2
- diffusers/utils/constants.py +2 -5
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
- diffusers/utils/dynamic_modules_utils.py +14 -18
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +1 -1
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +3 -3
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
--- diffusers/models/modeling_flax_utils.py (0.24.0)
+++ diffusers/models/modeling_flax_utils.py (0.25.0)
@@ -24,13 +24,17 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from huggingface_hub import create_repo, hf_hub_download
-from huggingface_hub.utils import
+from huggingface_hub.utils import (
+    EntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+    validate_hf_hub_args,
+)
 from requests import HTTPError
 
 from .. import __version__, is_torch_available
 from ..utils import (
     CONFIG_NAME,
-    DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     WEIGHTS_NAME,
@@ -197,6 +201,7 @@ class FlaxModelMixin(PushToHubMixin):
         raise NotImplementedError(f"init_weights method has to be implemented for {self}")
 
     @classmethod
+    @validate_hf_hub_args
     def from_pretrained(
         cls,
         pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -288,13 +293,13 @@ class FlaxModelMixin(PushToHubMixin):
         ```
         """
         config = kwargs.pop("config", None)
-        cache_dir = kwargs.pop("cache_dir",
+        cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
-
+        token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
 
@@ -314,7 +319,7 @@ class FlaxModelMixin(PushToHubMixin):
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-
+            token=token,
             revision=revision,
             subfolder=subfolder,
             **kwargs,
@@ -359,7 +364,7 @@ class FlaxModelMixin(PushToHubMixin):
                 proxies=proxies,
                 resume_download=resume_download,
                 local_files_only=local_files_only,
-
+                token=token,
                 user_agent=user_agent,
                 subfolder=subfolder,
                 revision=revision,
@@ -369,7 +374,7 @@ class FlaxModelMixin(PushToHubMixin):
             raise EnvironmentError(
                 f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                 "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `
+                "token having permission to this repo with `token` or log in with `huggingface-cli "
                 "login`."
             )
         except RevisionNotFoundError:
--- diffusers/models/modeling_utils.py (0.24.0)
+++ diffusers/models/modeling_utils.py (0.25.0)
@@ -25,14 +25,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 import safetensors
 import torch
 from huggingface_hub import create_repo
+from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
 from ..utils import (
     CONFIG_NAME,
-    DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
-    HF_HUB_OFFLINE,
     MIN_PEFT_VERSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
@@ -535,6 +534,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         )
 
     @classmethod
+    @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
         Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -571,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             local_files_only(`bool`, *optional*, defaults to `False`):
                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
                 won't be downloaded from the Hub.
-
+            token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
@@ -640,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         ```
         """
-        cache_dir = kwargs.pop("cache_dir",
+        cache_dir = kwargs.pop("cache_dir", None)
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         from_flax = kwargs.pop("from_flax", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
-        local_files_only = kwargs.pop("local_files_only",
-
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
@@ -718,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
-
+                token=token,
                 revision=revision,
                 subfolder=subfolder,
                 device_map=device_map,
@@ -740,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
-
+                token=token,
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
@@ -763,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
-
+                token=token,
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
@@ -782,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
-
+                token=token,
                 revision=revision,
                 subfolder=subfolder,
                 user_agent=user_agent,
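
The same change lands on the PyTorch side: `ModelMixin.from_pretrained` now pops `token` and is wrapped in `@validate_hf_hub_args`. A hedged sketch of the call, again with a placeholder model id:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example repo id, not part of the diff
    subfolder="unet",
    torch_dtype=torch.float16,
    token=True,  # or a personal access token string for private repos
)
```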
--- diffusers/models/normalization.py (0.24.0)
+++ diffusers/models/normalization.py (0.25.0)
@@ -13,14 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numbers
 from typing import Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings,
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
 class AdaLayerNorm(nn.Module):
@@ -91,7 +93,7 @@ class AdaLayerNormSingle(nn.Module):
     def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
         super().__init__()
 
-        self.emb =
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
             embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
         )
 
@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
         x = F.group_norm(x, self.num_groups, eps=self.eps)
         x = x * (1 + scale) + shift
         return x
+
+
+class AdaLayerNormContinuous(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        emb = self.linear(self.silu(conditioning_embedding))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
+if is_torch_version(">=", "2.1.0"):
+    LayerNorm = nn.LayerNorm
+else:
+    # Has optional bias parameter compared to torch layer norm
+    # TODO: replace with torch layernorm once min required torch version >= 2.1
+    class LayerNorm(nn.Module):
+        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
+            super().__init__()
+
+            self.eps = eps
+
+            if isinstance(dim, numbers.Integral):
+                dim = (dim,)
+
+            self.dim = torch.Size(dim)
+
+            if elementwise_affine:
+                self.weight = nn.Parameter(torch.ones(dim))
+                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
+            else:
+                self.weight = None
+                self.bias = None
+
+        def forward(self, input):
+            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
+class GlobalResponseNorm(nn.Module):
+    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+    def forward(self, x):
+        gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
+        return self.gamma * (x * nx) + self.beta + x