diffusers-0.24.0-py3-none-any.whl → diffusers-0.25.0-py3-none-any.whl

Files changed (174)
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +2 -1
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +30 -29
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/single_file.py CHANGED
@@ -18,10 +18,9 @@ from pathlib import Path
 import requests
 import torch
 from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import validate_hf_hub_args
 
 from ..utils import (
-    DIFFUSERS_CACHE,
-    HF_HUB_OFFLINE,
     deprecate,
     is_accelerate_available,
     is_omegaconf_available,
@@ -52,6 +51,7 @@ class FromSingleFileMixin:
         return cls.from_single_file(*args, **kwargs)
 
     @classmethod
+    @validate_hf_hub_args
     def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         r"""
         Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
@@ -81,7 +81,7 @@ class FromSingleFileMixin:
             local_files_only (`bool`, *optional*, defaults to `False`):
                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
                 won't be downloaded from the Hub.
-            use_auth_token (`str` or *bool*, *optional*):
+            token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
@@ -154,12 +154,12 @@ class FromSingleFileMixin:
 
         original_config_file = kwargs.pop("original_config_file", None)
         config_files = kwargs.pop("config_files", None)
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        cache_dir = kwargs.pop("cache_dir", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         extract_ema = kwargs.pop("extract_ema", False)
         image_size = kwargs.pop("image_size", None)
@@ -169,10 +169,12 @@ class FromSingleFileMixin:
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
         text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
         vae = kwargs.pop("vae", None)
         controlnet = kwargs.pop("controlnet", None)
         adapter = kwargs.pop("adapter", None)
         tokenizer = kwargs.pop("tokenizer", None)
+        tokenizer_2 = kwargs.pop("tokenizer_2", None)
 
         torch_dtype = kwargs.pop("torch_dtype", None)
 
@@ -253,7 +255,7 @@ class FromSingleFileMixin:
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             force_download=force_download,
         )
@@ -274,15 +276,17 @@ class FromSingleFileMixin:
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
             text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
             vae=vae,
             tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
             original_config_file=original_config_file,
             config_files=config_files,
             local_files_only=local_files_only,
         )
 
         if torch_dtype is not None:
-            pipe.to(torch_dtype=torch_dtype)
+            pipe.to(dtype=torch_dtype)
 
         return pipe
 
@@ -293,6 +297,7 @@ class FromOriginalVAEMixin:
     """
 
     @classmethod
+    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         r"""
         Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
@@ -322,7 +327,7 @@ class FromOriginalVAEMixin:
             local_files_only (`bool`, *optional*, defaults to `False`):
                 Whether to only load local model weights and configuration files or not. If set to True, the model
                 won't be downloaded from the Hub.
-            use_auth_token (`str` or *bool*, *optional*):
+            token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
@@ -379,12 +384,12 @@ class FromOriginalVAEMixin:
         )
 
         config_file = kwargs.pop("config_file", None)
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        cache_dir = kwargs.pop("cache_dir", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         image_size = kwargs.pop("image_size", None)
         scaling_factor = kwargs.pop("scaling_factor", None)
@@ -425,7 +430,7 @@ class FromOriginalVAEMixin:
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             force_download=force_download,
         )
@@ -490,6 +495,7 @@ class FromOriginalControlnetMixin:
     """
 
     @classmethod
+    @validate_hf_hub_args
     def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         r"""
         Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
@@ -519,7 +525,7 @@ class FromOriginalControlnetMixin:
             local_files_only (`bool`, *optional*, defaults to `False`):
                 Whether to only load local model weights and configuration files or not. If set to True, the model
                 won't be downloaded from the Hub.
-            use_auth_token (`str` or *bool*, *optional*):
+            token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
@@ -555,12 +561,12 @@ class FromOriginalControlnetMixin:
         from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
 
         config_file = kwargs.pop("config_file", None)
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        cache_dir = kwargs.pop("cache_dir", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
         num_in_channels = kwargs.pop("num_in_channels", None)
         use_linear_projection = kwargs.pop("use_linear_projection", None)
         revision = kwargs.pop("revision", None)
@@ -603,7 +609,7 @@ class FromOriginalControlnetMixin:
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             force_download=force_download,
         )
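A minimal usage sketch of the updated single-file loaders (not part of the diff; the checkpoint URL and token value are illustrative): `token` replaces the deprecated `use_auth_token`, SDXL components can be overridden through the new `text_encoder_2`/`tokenizer_2` kwargs, and `torch_dtype` is now applied correctly via `pipe.to(dtype=...)`.

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
    token="hf_...",             # renamed from `use_auth_token` in 0.25.0
    torch_dtype=torch.float16,  # forwarded as pipe.to(dtype=torch.float16)
)
```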
diffusers/loaders/textual_inversion.py CHANGED
@@ -15,16 +15,10 @@ from typing import Dict, List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
-from ..utils import (
-    DIFFUSERS_CACHE,
-    HF_HUB_OFFLINE,
-    _get_model_file,
-    is_accelerate_available,
-    is_transformers_available,
-    logging,
-)
+from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
 
 
 if is_transformers_available():
@@ -39,13 +33,14 @@ TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 
 
+@validate_hf_hub_args
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
-    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+    cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
     resume_download = kwargs.pop("resume_download", False)
     proxies = kwargs.pop("proxies", None)
-    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-    use_auth_token = kwargs.pop("use_auth_token", None)
+    local_files_only = kwargs.pop("local_files_only", None)
+    token = kwargs.pop("token", None)
     revision = kwargs.pop("revision", None)
     subfolder = kwargs.pop("subfolder", None)
     weight_name = kwargs.pop("weight_name", None)
@@ -79,7 +74,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
@@ -100,7 +95,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
@@ -267,6 +262,7 @@ class TextualInversionLoaderMixin:
 
         return all_tokens, all_embeddings
 
+    @validate_hf_hub_args
     def load_textual_inversion(
         self,
         pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -320,7 +316,7 @@ class TextualInversionLoaderMixin:
             local_files_only (`bool`, *optional*, defaults to `False`):
                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
                 won't be downloaded from the Hub.
-            use_auth_token (`str` or *bool*, *optional*):
+            token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
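A short sketch of the textual-inversion loader after the rename (the repo ids are illustrative examples from the Hub; `token` is only needed for gated or private repos):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# `load_textual_inversion` is now decorated with `validate_hf_hub_args`
# and takes `token` instead of `use_auth_token`.
pipe.load_textual_inversion("sd-concepts-library/cat-toy", token="hf_...")
image = pipe("a <cat-toy> riding a bicycle").images[0]
```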
diffusers/loaders/unet.py CHANGED
@@ -11,21 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from collections import defaultdict
 from contextlib import nullcontext
+from functools import partial
 from typing import Callable, Dict, List, Optional, Union
 
 import safetensors
 import torch
 import torch.nn.functional as F
+from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
-from ..models.embeddings import ImageProjection
+from ..models.embeddings import ImageProjection, MLPProjection, Resampler
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
-    DIFFUSERS_CACHE,
-    HF_HUB_OFFLINE,
     USE_PEFT_BACKEND,
     _get_model_file,
     delete_adapter_layers,
@@ -62,6 +63,7 @@ class UNet2DConditionLoadersMixin:
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
 
+    @validate_hf_hub_args
     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
@@ -95,7 +97,7 @@ class UNet2DConditionLoadersMixin:
             local_files_only (`bool`, *optional*, defaults to `False`):
                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
                 won't be downloaded from the Hub.
-            use_auth_token (`str` or *bool*, *optional*):
+            token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
@@ -130,12 +132,12 @@ class UNet2DConditionLoadersMixin:
         from ..models.attention_processor import CustomDiffusionAttnProcessor
         from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
 
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
@@ -184,7 +186,7 @@ class UNet2DConditionLoadersMixin:
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
@@ -204,7 +206,7 @@ class UNet2DConditionLoadersMixin:
             resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
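Correspondingly, `load_attn_procs` takes `token` as well; a hedged sketch (the LoRA repo id is illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Load attention-processor (LoRA) weights from the Hub with the renamed kwarg.
pipe.unet.load_attn_procs("sayakpaul/sd-model-finetuned-lora-t4", token="hf_...")
```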
@@ -504,22 +506,43 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
-        self.apply(self._fuse_lora_apply)
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
 
-    def _fuse_lora_apply(self, module):
+    def _fuse_lora_apply(self, module, adapter_names=None):
         if not USE_PEFT_BACKEND:
             if hasattr(module, "_fuse_lora"):
                 module._fuse_lora(self.lora_scale, self._safe_fusing)
+
+            if adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported in your environment. Please switch"
+                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
+                    " `pip install -U peft transformers`"
+                )
         else:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
+            merge_kwargs = {"safe_merge": self._safe_fusing}
+
             if isinstance(module, BaseTunerLayer):
                 if self.lora_scale != 1.0:
                     module.scale_layer(self.lora_scale)
-                module.merge(safe_merge=self._safe_fusing)
+
+                # For BC with previous PEFT versions, we need to check the signature
+                # of the `merge` method to see if it supports the `adapter_names` argument.
+                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                if "adapter_names" in supported_merge_kwargs:
+                    merge_kwargs["adapter_names"] = adapter_names
+                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                        " to the latest version of PEFT. `pip install -U peft`"
+                    )
+
+                module.merge(**merge_kwargs)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
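A sketch of the new `adapter_names` flow (assumes the PEFT backend is active and the installed PEFT's `merge()` accepts `adapter_names`; repo ids, weight names, and adapter names are illustrative):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
# Fuse only the selected adapter into the base weights; without PEFT (or with an
# older PEFT whose `merge()` lacks `adapter_names`) this hits the ValueError paths above.
pipe.fuse_lora(lora_scale=0.9, adapter_names=["pixel"])
```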
@@ -664,6 +687,80 @@ class UNet2DConditionLoadersMixin:
         if hasattr(self, "peft_config"):
             self.peft_config.pop(adapter_name, None)
 
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+        updated_state_dict = {}
+        image_projection = None
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        elif "proj.3.weight" in state_dict:
+            # IP-Adapter Full
+            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+            cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+
+            image_projection = MLPProjection(
+                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                diffusers_name = diffusers_name.replace("proj.3", "norm")
+                updated_state_dict[diffusers_name] = value
+
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["latents"].shape[1]
+            embed_dims = state_dict["proj_in.weight"].shape[1]
+            output_dims = state_dict["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["latents"].shape[2]
+            heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("0.to", "2.to")
+                diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
+                diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
+                diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        image_projection.load_state_dict(updated_state_dict)
+        return image_projection
+
     def _load_ip_adapter_weights(self, state_dict):
         from ..models.attention_processor import (
             AttnProcessor,
@@ -672,6 +769,20 @@ class UNet2DConditionLoadersMixin:
             IPAdapterAttnProcessor2_0,
         )
 
+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            num_image_text_embeds = 4
+        elif "proj.3.weight" in state_dict["image_proj"]:
+            # IP-Adapter Full Face
+            num_image_text_embeds = 257  # 256 CLIP tokens + 1 CLS token
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
+
+        # Set encoder_hid_proj after loading ip_adapter weights,
+        # because `Resampler` also has `attn_processors`.
+        self.encoder_hid_proj = None
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
@@ -695,7 +806,10 @@ class UNet2DConditionLoadersMixin:
                 IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
             )
             attn_procs[name] = attn_processor_class(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                hidden_size=hidden_size,
+                cross_attention_dim=cross_attention_dim,
+                scale=1.0,
+                num_tokens=num_image_text_embeds,
             ).to(dtype=self.dtype, device=self.device)
 
             value_dict = {}
@@ -707,29 +821,8 @@ class UNet2DConditionLoadersMixin:
 
         self.set_attn_processor(attn_procs)
 
-        # create image projection layers.
-        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-        image_projection = ImageProjection(
-            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
-        )
-        image_projection.to(dtype=self.dtype, device=self.device)
-
-        # load image projection layer weights
-        image_proj_state_dict = {}
-        image_proj_state_dict.update(
-            {
-                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                "norm.weight": state_dict["image_proj"]["norm.weight"],
-                "norm.bias": state_dict["image_proj"]["norm.bias"],
-            }
-        )
-
-        image_projection.load_state_dict(image_proj_state_dict)
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
 
         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
-
-        delete_adapter_layers
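The branch conditions above double as a checkpoint classifier; the same key heuristic, extracted as a standalone sketch:

```python
def detect_ip_adapter_variant(image_proj_state_dict: dict) -> str:
    """Classify an IP-Adapter checkpoint by its image-projection state-dict keys."""
    if "proj.weight" in image_proj_state_dict:
        return "ip-adapter"            # plain: 4 learned image tokens (ImageProjection)
    if "proj.3.weight" in image_proj_state_dict:
        return "ip-adapter-full-face"  # 257 tokens = 256 CLIP + 1 CLS (MLPProjection)
    return "ip-adapter-plus"           # Resampler; token count read from `latents`
```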
diffusers/models/__init__.py CHANGED
@@ -26,13 +26,14 @@ _import_structure = {}
 
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
-    _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
-    _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
-    _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
-    _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
-    _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
+    _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
+    _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
+    _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["prior_transformer"] = ["PriorTransformer"]
     _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
@@ -42,9 +43,10 @@ if is_torch_available():
     _import_structure["unet_2d"] = ["UNet2DModel"]
     _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
     _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
-    _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
+    _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
+    _import_structure["uvit_2d"] = ["UVit2DModel"]
     _import_structure["vq_model"] = ["VQModel"]
 
 if is_flax_available():
@@ -56,13 +58,16 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
-        from .autoencoder_asym_kl import AsymmetricAutoencoderKL
-        from .autoencoder_kl import AutoencoderKL
-        from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
-        from .autoencoder_tiny import AutoencoderTiny
-        from .consistency_decoder_vae import ConsistencyDecoderVAE
+        from .autoencoders import (
+            AsymmetricAutoencoderKL,
+            AutoencoderKL,
+            AutoencoderKLTemporalDecoder,
+            AutoencoderTiny,
+            ConsistencyDecoderVAE,
+        )
         from .controlnet import ControlNetModel
         from .dual_transformer_2d import DualTransformer2DModel
+        from .embeddings import ImageProjection
         from .modeling_utils import ModelMixin
         from .prior_transformer import PriorTransformer
         from .t5_film_transformer import T5FilmDecoder
@@ -72,9 +77,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .unet_2d import UNet2DModel
         from .unet_2d_condition import UNet2DConditionModel
         from .unet_3d_condition import UNet3DConditionModel
-        from .unet_kandi3 import Kandinsky3UNet
+        from .unet_kandinsky3 import Kandinsky3UNet
         from .unet_motion_model import MotionAdapter, UNetMotionModel
         from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
+        from .uvit_2d import UVit2DModel
         from .vq_model import VQModel
 
     if is_flax_available():
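The autoencoder move and the `unet_kandi3` → `unet_kandinsky3` rename keep the lazy public surface intact; only direct module imports need the new path:

```python
# Unchanged public imports (resolved lazily through _import_structure):
from diffusers.models import AutoencoderKL, ConsistencyDecoderVAE
# Newly exported in 0.25.0:
from diffusers.models import ImageProjection, UVit2DModel
# Direct module imports must use the renamed module:
from diffusers.models.unet_kandinsky3 import Kandinsky3UNet
```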
diffusers/models/activations.py CHANGED
@@ -55,11 +55,12 @@ class GELU(nn.Module):
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
         approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """
 
-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
     Parameters:
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """
 
-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
 
-        self.proj = linear_cls(dim_in, dim_out * 2)
+        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
     Parameters:
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """
 
-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
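A quick sketch of the new `bias` flag (GEGLU shown; GELU and ApproximateGELU gain the same parameter, and the dimensions here are arbitrary):

```python
import torch
from diffusers.models.activations import GEGLU

geglu = GEGLU(dim_in=320, dim_out=1280, bias=False)
assert geglu.proj.bias is None        # projection is now bias-free
out = geglu(torch.randn(2, 77, 320))  # -> torch.Size([2, 77, 1280])
```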