diffusers 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (174) hide show
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +2 -1
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +30 -29
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -24,13 +24,17 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
24
24
  from flax.serialization import from_bytes, to_bytes
25
25
  from flax.traverse_util import flatten_dict, unflatten_dict
26
26
  from huggingface_hub import create_repo, hf_hub_download
27
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
27
+ from huggingface_hub.utils import (
28
+ EntryNotFoundError,
29
+ RepositoryNotFoundError,
30
+ RevisionNotFoundError,
31
+ validate_hf_hub_args,
32
+ )
28
33
  from requests import HTTPError
29
34
 
30
35
  from .. import __version__, is_torch_available
31
36
  from ..utils import (
32
37
  CONFIG_NAME,
33
- DIFFUSERS_CACHE,
34
38
  FLAX_WEIGHTS_NAME,
35
39
  HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
40
  WEIGHTS_NAME,
@@ -197,6 +201,7 @@ class FlaxModelMixin(PushToHubMixin):
197
201
  raise NotImplementedError(f"init_weights method has to be implemented for {self}")
198
202
 
199
203
  @classmethod
204
+ @validate_hf_hub_args
200
205
  def from_pretrained(
201
206
  cls,
202
207
  pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -288,13 +293,13 @@ class FlaxModelMixin(PushToHubMixin):
288
293
  ```
289
294
  """
290
295
  config = kwargs.pop("config", None)
291
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
296
+ cache_dir = kwargs.pop("cache_dir", None)
292
297
  force_download = kwargs.pop("force_download", False)
293
298
  from_pt = kwargs.pop("from_pt", False)
294
299
  resume_download = kwargs.pop("resume_download", False)
295
300
  proxies = kwargs.pop("proxies", None)
296
301
  local_files_only = kwargs.pop("local_files_only", False)
297
- use_auth_token = kwargs.pop("use_auth_token", None)
302
+ token = kwargs.pop("token", None)
298
303
  revision = kwargs.pop("revision", None)
299
304
  subfolder = kwargs.pop("subfolder", None)
300
305
 
@@ -314,7 +319,7 @@ class FlaxModelMixin(PushToHubMixin):
314
319
  resume_download=resume_download,
315
320
  proxies=proxies,
316
321
  local_files_only=local_files_only,
317
- use_auth_token=use_auth_token,
322
+ token=token,
318
323
  revision=revision,
319
324
  subfolder=subfolder,
320
325
  **kwargs,
@@ -359,7 +364,7 @@ class FlaxModelMixin(PushToHubMixin):
359
364
  proxies=proxies,
360
365
  resume_download=resume_download,
361
366
  local_files_only=local_files_only,
362
- use_auth_token=use_auth_token,
367
+ token=token,
363
368
  user_agent=user_agent,
364
369
  subfolder=subfolder,
365
370
  revision=revision,
@@ -369,7 +374,7 @@ class FlaxModelMixin(PushToHubMixin):
369
374
  raise EnvironmentError(
370
375
  f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
371
376
  "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
372
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
377
+ "token having permission to this repo with `token` or log in with `huggingface-cli "
373
378
  "login`."
374
379
  )
375
380
  except RevisionNotFoundError:
@@ -25,14 +25,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
25
25
  import safetensors
26
26
  import torch
27
27
  from huggingface_hub import create_repo
28
+ from huggingface_hub.utils import validate_hf_hub_args
28
29
  from torch import Tensor, nn
29
30
 
30
31
  from .. import __version__
31
32
  from ..utils import (
32
33
  CONFIG_NAME,
33
- DIFFUSERS_CACHE,
34
34
  FLAX_WEIGHTS_NAME,
35
- HF_HUB_OFFLINE,
36
35
  MIN_PEFT_VERSION,
37
36
  SAFETENSORS_WEIGHTS_NAME,
38
37
  WEIGHTS_NAME,
@@ -535,6 +534,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
535
534
  )
536
535
 
537
536
  @classmethod
537
+ @validate_hf_hub_args
538
538
  def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
539
539
  r"""
540
540
  Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -571,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
571
571
  local_files_only(`bool`, *optional*, defaults to `False`):
572
572
  Whether to only load local model weights and configuration files or not. If set to `True`, the model
573
573
  won't be downloaded from the Hub.
574
- use_auth_token (`str` or *bool*, *optional*):
574
+ token (`str` or *bool*, *optional*):
575
575
  The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
576
576
  `diffusers-cli login` (stored in `~/.huggingface`) is used.
577
577
  revision (`str`, *optional*, defaults to `"main"`):
@@ -640,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
640
640
  You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
641
641
  ```
642
642
  """
643
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
643
+ cache_dir = kwargs.pop("cache_dir", None)
644
644
  ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
645
645
  force_download = kwargs.pop("force_download", False)
646
646
  from_flax = kwargs.pop("from_flax", False)
647
647
  resume_download = kwargs.pop("resume_download", False)
648
648
  proxies = kwargs.pop("proxies", None)
649
649
  output_loading_info = kwargs.pop("output_loading_info", False)
650
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
651
- use_auth_token = kwargs.pop("use_auth_token", None)
650
+ local_files_only = kwargs.pop("local_files_only", None)
651
+ token = kwargs.pop("token", None)
652
652
  revision = kwargs.pop("revision", None)
653
653
  torch_dtype = kwargs.pop("torch_dtype", None)
654
654
  subfolder = kwargs.pop("subfolder", None)
@@ -718,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
718
718
  resume_download=resume_download,
719
719
  proxies=proxies,
720
720
  local_files_only=local_files_only,
721
- use_auth_token=use_auth_token,
721
+ token=token,
722
722
  revision=revision,
723
723
  subfolder=subfolder,
724
724
  device_map=device_map,
@@ -740,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
740
740
  resume_download=resume_download,
741
741
  proxies=proxies,
742
742
  local_files_only=local_files_only,
743
- use_auth_token=use_auth_token,
743
+ token=token,
744
744
  revision=revision,
745
745
  subfolder=subfolder,
746
746
  user_agent=user_agent,
@@ -763,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
763
763
  resume_download=resume_download,
764
764
  proxies=proxies,
765
765
  local_files_only=local_files_only,
766
- use_auth_token=use_auth_token,
766
+ token=token,
767
767
  revision=revision,
768
768
  subfolder=subfolder,
769
769
  user_agent=user_agent,
@@ -782,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
782
782
  resume_download=resume_download,
783
783
  proxies=proxies,
784
784
  local_files_only=local_files_only,
785
- use_auth_token=use_auth_token,
785
+ token=token,
786
786
  revision=revision,
787
787
  subfolder=subfolder,
788
788
  user_agent=user_agent,
@@ -13,14 +13,16 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import numbers
16
17
  from typing import Dict, Optional, Tuple
17
18
 
18
19
  import torch
19
20
  import torch.nn as nn
20
21
  import torch.nn.functional as F
21
22
 
23
+ from ..utils import is_torch_version
22
24
  from .activations import get_activation
23
- from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
25
+ from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
24
26
 
25
27
 
26
28
  class AdaLayerNorm(nn.Module):
@@ -91,7 +93,7 @@ class AdaLayerNormSingle(nn.Module):
91
93
  def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
92
94
  super().__init__()
93
95
 
94
- self.emb = CombinedTimestepSizeEmbeddings(
96
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
95
97
  embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
96
98
  )
97
99
 
@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
146
148
  x = F.group_norm(x, self.num_groups, eps=self.eps)
147
149
  x = x * (1 + scale) + shift
148
150
  return x
151
+
152
+
153
class AdaLayerNormContinuous(nn.Module):
    r"""
    Normalization layer whose output is modulated by a scale and shift regressed from a continuous conditioning
    embedding.

    The conditioning embedding is passed through SiLU and one linear projection to `2 * embedding_dim` features.
    Those features are split into `scale` (first half) and `shift` (second half). The result is
    `norm(x) * (1 + scale) + shift`, where a singleton axis is inserted into `scale`/`shift` at position 1 so
    they broadcast against `x`. Presumably `conditioning_embedding` is `(batch, conditioning_embedding_dim)` and
    `x` is `(batch, seq, embedding_dim)`; TODO confirm with callers.

    Parameters:
        embedding_dim (`int`): Size of the last dimension of the input being normalized.
        conditioning_embedding_dim (`int`): Size of the last dimension of the conditioning embedding.
        elementwise_affine (`bool`, defaults to `True`): Whether the inner norm has its own learnable weight
            (and bias, for layer norm).
        eps (`float`, defaults to `1e-5`): Numerical-stability epsilon forwarded to the inner norm.
        bias (`bool`, defaults to `True`): Whether the conditioning projection uses a bias. For
            `"layer_norm"` it is also forwarded to the inner norm.
        norm_type (`str`, defaults to `"layer_norm"`): Either `"layer_norm"` or `"rms_norm"`.

    Raises:
        ValueError: If `norm_type` is not one of the supported values.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)

        # Factories are evaluated lazily, so only the selected norm class is instantiated.
        norm_builders = {
            "layer_norm": lambda: LayerNorm(embedding_dim, eps, elementwise_affine, bias),
            "rms_norm": lambda: RMSNorm(embedding_dim, eps, elementwise_affine),
        }
        build_norm = norm_builders.get(norm_type)
        if build_norm is None:
            raise ValueError(f"unknown norm_type {norm_type}")
        self.norm = build_norm()

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        projected = self.linear(self.silu(conditioning_embedding))
        scale, shift = projected.chunk(2, dim=1)
        # Insert an axis at position 1 so the per-sample modulation broadcasts over that axis of `x`.
        scale = scale[:, None, :]
        shift = shift[:, None, :]
        return self.norm(x) * (1 + scale) + shift
183
+
184
+
185
if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        r"""
        Stand-in for `torch.nn.LayerNorm` on torch < 2.1 that additionally accepts a `bias` flag.

        Parameters:
            dim (`int` or sequence of `int`): Trailing shape to normalize over. Stored as `self.dim` and,
                matching `nn.LayerNorm`, as `self.normalized_shape`.
            eps (`float`, defaults to `1e-5`): Value forwarded to `F.layer_norm` for numerical stability.
            elementwise_affine (`bool`, defaults to `True`): Whether to learn a weight (initialized to ones)
                and, if `bias` is set, a bias (initialized to zeros).
            bias (`bool`, defaults to `True`): Whether to learn an additive bias. It only has an effect when
                `elementwise_affine` is `True`.
        """

        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)
            # Expose the same attribute names as `nn.LayerNorm`. Code that inspects a norm's configuration
            # then behaves identically whichever implementation `LayerNorm` resolves to.
            self.normalized_shape = self.dim
            self.elementwise_affine = elementwise_affine

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
210
+
211
+
212
class RMSNorm(nn.Module):
    r"""
    Root-mean-square normalization over the last dimension.

    The input is multiplied by `rsqrt(mean(x ** 2, dim=-1) + eps)`, with the mean square computed in float32.
    When `elementwise_affine` is set, the result is then multiplied by a learnable weight.

    Parameters:
        dim (`int` or sequence of `int`): Shape of the learnable weight; stored as `self.dim`.
        eps (`float`): Value added to the mean square for numerical stability.
        elementwise_affine (`bool`, defaults to `True`): Whether to learn a weight initialized to ones.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        shape = (dim,) if isinstance(dim, numbers.Integral) else dim
        self.dim = torch.Size(shape)

        self.weight = nn.Parameter(torch.ones(shape)) if elementwise_affine else None

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)

        if self.weight is None:
            # Without a weight, hand back the caller's dtype.
            return normalized.to(original_dtype)

        # A half-precision weight pulls the activations down to its dtype before scaling. Otherwise the
        # product keeps whatever dtype type promotion yields.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(self.weight.dtype)
        return normalized * self.weight
242
+
243
+
244
class GlobalResponseNorm(nn.Module):
    r"""
    Global Response Normalization, following the ConvNeXt-V2 reference implementation.

    The layer takes the L2 norm of `x` over dims 1 and 2 and divides it by its mean over the last dim (plus
    `1e-6`). It returns `gamma * (x * that_ratio) + beta + x`. `gamma` and `beta` have shape `(1, 1, 1, dim)`
    and start at zero, so a freshly built layer is the identity. Presumably `x` is channels-last
    `(batch, height, width, dim)` as in the reference code; TODO confirm with callers.

    Parameters:
        dim (`int`): Size of the last dimension of `x`.
    """

    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        aggregated = x.norm(p=2, dim=(1, 2), keepdim=True)
        ratio = aggregated / (aggregated.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * ratio) + self.beta + x