diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (238) hide show
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
diffusers/models/lora.py CHANGED
@@ -12,19 +12,60 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
16
+ # IMPORTANT: #
17
+ ###################################################################
18
+ # ----------------------------------------------------------------#
19
+ # This file is deprecated and will be removed soon #
20
+ # (as soon as PEFT becomes a required dependency for LoRA) #
21
+ # ----------------------------------------------------------------#
22
+ ###################################################################
23
+
15
24
  from typing import Optional, Tuple, Union
16
25
 
17
26
  import torch
18
27
  import torch.nn.functional as F
19
28
  from torch import nn
20
29
 
21
- from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
22
30
  from ..utils import logging
31
+ from ..utils.import_utils import is_transformers_available
32
+
33
+
34
+ if is_transformers_available():
35
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
23
36
 
24
37
 
25
38
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
39
 
27
40
 
41
def text_encoder_attn_modules(text_encoder):
    """
    Collect the self-attention sub-modules of a CLIP text encoder.

    Args:
        text_encoder: A `CLIPTextModel` or `CLIPTextModelWithProjection` instance.

    Returns:
        `list` of `(name, module)` tuples, one per entry of `text_encoder.text_model.encoder.layers`, where `name`
        is `"text_model.encoder.layers.{i}.self_attn"` and `module` is that layer's `self_attn` attribute.

    Raises:
        ValueError: If `text_encoder` is not an instance of one of the supported classes. This is also raised when
            `transformers` cannot be imported, because then no supported class exists.
    """
    # Resolve the supported classes here instead of relying on module-level names that only exist when
    # `transformers` is installed; otherwise the `isinstance` check below would fail with a `NameError`
    # rather than the intended `ValueError`.
    try:
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        supported_types = (CLIPTextModel, CLIPTextModelWithProjection)
    except ImportError:
        # `isinstance(obj, ())` is always False, so every input falls through to the error below.
        supported_types = ()

    if not isinstance(text_encoder, supported_types):
        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

    return [
        (f"text_model.encoder.layers.{i}.self_attn", layer.self_attn)
        for i, layer in enumerate(text_encoder.text_model.encoder.layers)
    ]
53
+
54
+
55
def text_encoder_mlp_modules(text_encoder):
    """
    Collect the MLP sub-modules of a CLIP text encoder.

    Args:
        text_encoder: A `CLIPTextModel` or `CLIPTextModelWithProjection` instance.

    Returns:
        `list` of `(name, module)` tuples, one per entry of `text_encoder.text_model.encoder.layers`, where `name`
        is `"text_model.encoder.layers.{i}.mlp"` and `module` is that layer's `mlp` attribute.

    Raises:
        ValueError: If `text_encoder` is not an instance of one of the supported classes. This is also raised when
            `transformers` cannot be imported, because then no supported class exists.
    """
    # Resolve the supported classes here instead of relying on module-level names that only exist when
    # `transformers` is installed; otherwise the `isinstance` check below would fail with a `NameError`
    # rather than the intended `ValueError`.
    try:
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        supported_types = (CLIPTextModel, CLIPTextModelWithProjection)
    except ImportError:
        # `isinstance(obj, ())` is always False, so every input falls through to the error below.
        supported_types = ()

    if not isinstance(text_encoder, supported_types):
        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

    return [
        (f"text_model.encoder.layers.{i}.mlp", layer.mlp)
        for i, layer in enumerate(text_encoder.text_model.encoder.layers)
    ]
67
+
68
+
28
69
  def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
29
70
  for _, attn_module in text_encoder_attn_modules(text_encoder):
30
71
  if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -39,6 +80,95 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
39
80
  mlp_module.fc2.lora_scale = lora_scale
40
81
 
41
82
 
83
class PatchedLoraProjection(torch.nn.Module):
    """
    Wrap a linear layer and add the scaled output of a `LoRALinearLayer` to its result.

    While `lora_linear_layer` is set, `forward` returns
    `regular_linear_layer(input) + lora_scale * lora_linear_layer(input)`. After `_fuse_lora` the LoRA update is
    baked into the wrapped layer's weight, `lora_linear_layer` is dropped, and `forward` only calls the wrapped
    layer. `_unfuse_lora` subtracts the baked-in update again.

    Args:
        regular_linear_layer: The linear layer to wrap; its `weight`, `in_features` and `out_features` are read.
        lora_scale: Multiplier applied to the LoRA branch.
        network_alpha: Forwarded to `LoRALinearLayer`.
        rank: Forwarded to `LoRALinearLayer`.
        dtype: dtype for the LoRA layer; defaults to the dtype of the wrapped layer's weight.
    """

    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
        super().__init__()
        # NOTE(review): `LoRALinearLayer` appears to be defined in this same module; the function-scope
        # import is kept unchanged -- confirm whether it is still required.
        from ..models.lora import LoRALinearLayer

        self.regular_linear_layer = regular_linear_layer

        device = self.regular_linear_layer.weight.device

        if dtype is None:
            dtype = self.regular_linear_layer.weight.dtype

        self.lora_linear_layer = LoRALinearLayer(
            self.regular_linear_layer.in_features,
            self.regular_linear_layer.out_features,
            network_alpha=network_alpha,
            device=device,
            dtype=dtype,
            rank=rank,
        )

        self.lora_scale = lora_scale

    @staticmethod
    def _lora_delta(w_up, w_down):
        """Return the product `w_up @ w_down`, computed as a batched matmul over a singleton batch."""
        return torch.bmm(w_up[None, :], w_down[None, :])[0]

    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
    # when saving the whole text encoder model and when LoRA is unloaded or fused
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return the wrapped layer's state dict once the LoRA layer is gone, else the regular module state dict."""
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )

        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        """
        Merge the LoRA update into the wrapped layer's weight and drop the LoRA layer.

        Does nothing when there is no LoRA layer. The (possibly alpha-scaled) up matrix, the down matrix and
        `lora_scale` are kept on the instance so that `_unfuse_lora` can undo the merge.

        Args:
            lora_scale: Multiplier applied to the LoRA update before it is added to the weight.
            safe_fusing: When `True`, check the fused weight for NaNs before writing it back.

        Raises:
            ValueError: If `safe_fusing` is `True` and the fused weight contains NaN values; the wrapped
                layer and the LoRA layer are left untouched in that case.
        """
        if self.lora_linear_layer is None:
            return

        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device

        # Do the arithmetic in float32 and cast back to the original dtype at the end.
        w_orig = self.regular_linear_layer.weight.data.float()
        w_up = self.lora_linear_layer.up.weight.data.float()
        w_down = self.lora_linear_layer.down.weight.data.float()

        if self.lora_linear_layer.network_alpha is not None:
            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

        fused_weight = w_orig + (lora_scale * self._lora_delta(w_up, w_down))

        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
                "LoRA weights will not be fused."
            )

        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_linear_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self.lora_scale = lora_scale

    def _unfuse_lora(self):
        """
        Subtract a previously fused LoRA update from the wrapped layer's weight.

        Does nothing unless both `w_up` and `w_down` were stored by `_fuse_lora`. The stored matrices are
        cleared afterwards; the LoRA layer itself is not recreated.
        """
        if getattr(self, "w_up", None) is None or getattr(self, "w_down", None) is None:
            return

        fused_weight = self.regular_linear_layer.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()

        unfused_weight = fused_weight.float() - (self.lora_scale * self._lora_delta(w_up, w_down))
        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, input):
        """Apply the wrapped layer and, while a LoRA layer is present, add its output scaled by `lora_scale`."""
        # A `lora_scale` of None is replaced by 1.0 and stored back on the instance.
        if self.lora_scale is None:
            self.lora_scale = 1.0
        if self.lora_linear_layer is None:
            return self.regular_linear_layer(input)
        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
170
+
171
+
42
172
  class LoRALinearLayer(nn.Module):
43
173
  r"""
44
174
  A linear layer that is used with LoRA.
@@ -24,13 +24,17 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
24
24
  from flax.serialization import from_bytes, to_bytes
25
25
  from flax.traverse_util import flatten_dict, unflatten_dict
26
26
  from huggingface_hub import create_repo, hf_hub_download
27
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
27
+ from huggingface_hub.utils import (
28
+ EntryNotFoundError,
29
+ RepositoryNotFoundError,
30
+ RevisionNotFoundError,
31
+ validate_hf_hub_args,
32
+ )
28
33
  from requests import HTTPError
29
34
 
30
35
  from .. import __version__, is_torch_available
31
36
  from ..utils import (
32
37
  CONFIG_NAME,
33
- DIFFUSERS_CACHE,
34
38
  FLAX_WEIGHTS_NAME,
35
39
  HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
40
  WEIGHTS_NAME,
@@ -52,6 +56,7 @@ class FlaxModelMixin(PushToHubMixin):
52
56
 
53
57
  - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
54
58
  """
59
+
55
60
  config_name = CONFIG_NAME
56
61
  _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
57
62
  _flax_internal_args = ["name", "parent", "dtype"]
@@ -196,6 +201,7 @@ class FlaxModelMixin(PushToHubMixin):
196
201
  raise NotImplementedError(f"init_weights method has to be implemented for {self}")
197
202
 
198
203
  @classmethod
204
+ @validate_hf_hub_args
199
205
  def from_pretrained(
200
206
  cls,
201
207
  pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -287,13 +293,13 @@ class FlaxModelMixin(PushToHubMixin):
287
293
  ```
288
294
  """
289
295
  config = kwargs.pop("config", None)
290
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
296
+ cache_dir = kwargs.pop("cache_dir", None)
291
297
  force_download = kwargs.pop("force_download", False)
292
298
  from_pt = kwargs.pop("from_pt", False)
293
299
  resume_download = kwargs.pop("resume_download", False)
294
300
  proxies = kwargs.pop("proxies", None)
295
301
  local_files_only = kwargs.pop("local_files_only", False)
296
- use_auth_token = kwargs.pop("use_auth_token", None)
302
+ token = kwargs.pop("token", None)
297
303
  revision = kwargs.pop("revision", None)
298
304
  subfolder = kwargs.pop("subfolder", None)
299
305
 
@@ -313,7 +319,7 @@ class FlaxModelMixin(PushToHubMixin):
313
319
  resume_download=resume_download,
314
320
  proxies=proxies,
315
321
  local_files_only=local_files_only,
316
- use_auth_token=use_auth_token,
322
+ token=token,
317
323
  revision=revision,
318
324
  subfolder=subfolder,
319
325
  **kwargs,
@@ -358,7 +364,7 @@ class FlaxModelMixin(PushToHubMixin):
358
364
  proxies=proxies,
359
365
  resume_download=resume_download,
360
366
  local_files_only=local_files_only,
361
- use_auth_token=use_auth_token,
367
+ token=token,
362
368
  user_agent=user_agent,
363
369
  subfolder=subfolder,
364
370
  revision=revision,
@@ -368,7 +374,7 @@ class FlaxModelMixin(PushToHubMixin):
368
374
  raise EnvironmentError(
369
375
  f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
370
376
  "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
371
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
377
+ "token having permission to this repo with `token` or log in with `huggingface-cli "
372
378
  "login`."
373
379
  )
374
380
  except RevisionNotFoundError:
@@ -436,7 +442,7 @@ class FlaxModelMixin(PushToHubMixin):
436
442
  # make sure all arrays are stored as jnp.ndarray
437
443
  # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
438
444
  # https://github.com/google/flax/issues/1261
439
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
445
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
440
446
 
441
447
  # flatten dicts
442
448
  state = flatten_dict(state)
@@ -0,0 +1,17 @@
1
+ from dataclasses import dataclass
2
+
3
+ from ..utils import BaseOutput
4
+
5
+
6
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Result container returned by the AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            The `Encoder` output, expressed as the mean and logvar of a `DiagonalGaussianDistribution`. Latents
            can be sampled from this distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
@@ -18,20 +18,20 @@ import inspect
18
18
  import itertools
19
19
  import os
20
20
  import re
21
+ from collections import OrderedDict
21
22
  from functools import partial
22
23
  from typing import Any, Callable, List, Optional, Tuple, Union
23
24
 
24
25
  import safetensors
25
26
  import torch
26
27
  from huggingface_hub import create_repo
27
- from torch import Tensor, device, nn
28
+ from huggingface_hub.utils import validate_hf_hub_args
29
+ from torch import Tensor, nn
28
30
 
29
31
  from .. import __version__
30
32
  from ..utils import (
31
33
  CONFIG_NAME,
32
- DIFFUSERS_CACHE,
33
34
  FLAX_WEIGHTS_NAME,
34
- HF_HUB_OFFLINE,
35
35
  MIN_PEFT_VERSION,
36
36
  SAFETENSORS_WEIGHTS_NAME,
37
37
  WEIGHTS_NAME,
@@ -61,7 +61,7 @@ if is_accelerate_available():
61
61
  from accelerate.utils.versions import is_torch_version
62
62
 
63
63
 
64
- def get_parameter_device(parameter: torch.nn.Module):
64
+ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
65
65
  try:
66
66
  parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
67
67
  return next(parameters_and_buffers).device
@@ -77,7 +77,7 @@ def get_parameter_device(parameter: torch.nn.Module):
77
77
  return first_tuple[1].device
78
78
 
79
79
 
80
- def get_parameter_dtype(parameter: torch.nn.Module):
80
+ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
81
81
  try:
82
82
  params = tuple(parameter.parameters())
83
83
  if len(params) > 0:
@@ -130,7 +130,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
130
130
  )
131
131
 
132
132
 
133
- def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
133
+ def load_model_dict_into_meta(
134
+ model,
135
+ state_dict: OrderedDict,
136
+ device: Optional[Union[str, torch.device]] = None,
137
+ dtype: Optional[Union[str, torch.dtype]] = None,
138
+ model_name_or_path: Optional[str] = None,
139
+ ) -> List[str]:
134
140
  device = device or torch.device("cpu")
135
141
  dtype = dtype or torch.float32
136
142
 
@@ -156,7 +162,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
156
162
  return unexpected_keys
157
163
 
158
164
 
159
- def _load_state_dict_into_model(model_to_load, state_dict):
165
+ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
160
166
  # Convert old format to new format if needed from a PyTorch state_dict
161
167
  # copy state_dict so _load_from_state_dict can modify it
162
168
  state_dict = state_dict.copy()
@@ -164,7 +170,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
164
170
 
165
171
  # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
166
172
  # so we need to apply the function recursively.
167
- def load(module: torch.nn.Module, prefix=""):
173
+ def load(module: torch.nn.Module, prefix: str = ""):
168
174
  args = (state_dict, prefix, {}, True, [], [], error_msgs)
169
175
  module._load_from_state_dict(*args)
170
176
 
@@ -186,6 +192,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
186
192
 
187
193
  - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
188
194
  """
195
+
189
196
  config_name = CONFIG_NAME
190
197
  _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
191
198
  _supports_gradient_checkpointing = False
@@ -220,7 +227,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
220
227
  """
221
228
  return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
222
229
 
223
- def enable_gradient_checkpointing(self):
230
+ def enable_gradient_checkpointing(self) -> None:
224
231
  """
225
232
  Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
226
233
  *checkpoint activations* in other frameworks).
@@ -229,7 +236,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
229
236
  raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
230
237
  self.apply(partial(self._set_gradient_checkpointing, value=True))
231
238
 
232
- def disable_gradient_checkpointing(self):
239
+ def disable_gradient_checkpointing(self) -> None:
233
240
  """
234
241
  Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
235
242
  *checkpoint activations* in other frameworks).
@@ -254,7 +261,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
254
261
  if isinstance(module, torch.nn.Module):
255
262
  fn_recursive_set_mem_eff(module)
256
263
 
257
- def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
264
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
258
265
  r"""
259
266
  Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
260
267
 
@@ -290,7 +297,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
290
297
  """
291
298
  self.set_use_memory_efficient_attention_xformers(True, attention_op)
292
299
 
293
- def disable_xformers_memory_efficient_attention(self):
300
+ def disable_xformers_memory_efficient_attention(self) -> None:
294
301
  r"""
295
302
  Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
296
303
  """
@@ -447,7 +454,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
447
454
  self,
448
455
  save_directory: Union[str, os.PathLike],
449
456
  is_main_process: bool = True,
450
- save_function: Callable = None,
457
+ save_function: Optional[Callable] = None,
451
458
  safe_serialization: bool = True,
452
459
  variant: Optional[str] = None,
453
460
  push_to_hub: bool = False,
@@ -527,6 +534,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
527
534
  )
528
535
 
529
536
  @classmethod
537
+ @validate_hf_hub_args
530
538
  def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
531
539
  r"""
532
540
  Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -563,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
563
571
  local_files_only(`bool`, *optional*, defaults to `False`):
564
572
  Whether to only load local model weights and configuration files or not. If set to `True`, the model
565
573
  won't be downloaded from the Hub.
566
- use_auth_token (`str` or *bool*, *optional*):
574
+ token (`str` or *bool*, *optional*):
567
575
  The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
568
576
  `diffusers-cli login` (stored in `~/.huggingface`) is used.
569
577
  revision (`str`, *optional*, defaults to `"main"`):
@@ -632,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
632
640
  You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
633
641
  ```
634
642
  """
635
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
643
+ cache_dir = kwargs.pop("cache_dir", None)
636
644
  ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
637
645
  force_download = kwargs.pop("force_download", False)
638
646
  from_flax = kwargs.pop("from_flax", False)
639
647
  resume_download = kwargs.pop("resume_download", False)
640
648
  proxies = kwargs.pop("proxies", None)
641
649
  output_loading_info = kwargs.pop("output_loading_info", False)
642
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
643
- use_auth_token = kwargs.pop("use_auth_token", None)
650
+ local_files_only = kwargs.pop("local_files_only", None)
651
+ token = kwargs.pop("token", None)
644
652
  revision = kwargs.pop("revision", None)
645
653
  torch_dtype = kwargs.pop("torch_dtype", None)
646
654
  subfolder = kwargs.pop("subfolder", None)
@@ -710,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
710
718
  resume_download=resume_download,
711
719
  proxies=proxies,
712
720
  local_files_only=local_files_only,
713
- use_auth_token=use_auth_token,
721
+ token=token,
714
722
  revision=revision,
715
723
  subfolder=subfolder,
716
724
  device_map=device_map,
@@ -732,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
732
740
  resume_download=resume_download,
733
741
  proxies=proxies,
734
742
  local_files_only=local_files_only,
735
- use_auth_token=use_auth_token,
743
+ token=token,
736
744
  revision=revision,
737
745
  subfolder=subfolder,
738
746
  user_agent=user_agent,
@@ -755,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
755
763
  resume_download=resume_download,
756
764
  proxies=proxies,
757
765
  local_files_only=local_files_only,
758
- use_auth_token=use_auth_token,
766
+ token=token,
759
767
  revision=revision,
760
768
  subfolder=subfolder,
761
769
  user_agent=user_agent,
@@ -774,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
774
782
  resume_download=resume_download,
775
783
  proxies=proxies,
776
784
  local_files_only=local_files_only,
777
- use_auth_token=use_auth_token,
785
+ token=token,
778
786
  revision=revision,
779
787
  subfolder=subfolder,
780
788
  user_agent=user_agent,
@@ -910,10 +918,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
910
918
  def _load_pretrained_model(
911
919
  cls,
912
920
  model,
913
- state_dict,
921
+ state_dict: OrderedDict,
914
922
  resolved_archive_file,
915
- pretrained_model_name_or_path,
916
- ignore_mismatched_sizes=False,
923
+ pretrained_model_name_or_path: Union[str, os.PathLike],
924
+ ignore_mismatched_sizes: bool = False,
917
925
  ):
918
926
  # Retrieve missing & unexpected_keys
919
927
  model_state_dict = model.state_dict()
@@ -1011,7 +1019,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1011
1019
  return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1012
1020
 
1013
1021
  @property
1014
- def device(self) -> device:
1022
+ def device(self) -> torch.device:
1015
1023
  """
1016
1024
  `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1017
1025
  device).
@@ -1063,7 +1071,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1063
1071
  else:
1064
1072
  return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1065
1073
 
1066
- def _convert_deprecated_attention_blocks(self, state_dict):
1074
+ def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
1067
1075
  deprecated_attention_block_paths = []
1068
1076
 
1069
1077
  def recursive_find_attn_block(name, module):
@@ -1107,7 +1115,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1107
1115
  if f"{path}.proj_attn.bias" in state_dict:
1108
1116
  state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1109
1117
 
1110
- def _temp_convert_self_to_deprecated_attention_blocks(self):
1118
+ def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1111
1119
  deprecated_attention_block_modules = []
1112
1120
 
1113
1121
  def recursive_find_attn_block(module):
@@ -1134,10 +1142,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1134
1142
  del module.to_v
1135
1143
  del module.to_out
1136
1144
 
1137
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
1145
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1138
1146
  deprecated_attention_block_modules = []
1139
1147
 
1140
- def recursive_find_attn_block(module):
1148
+ def recursive_find_attn_block(module) -> None:
1141
1149
  if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1142
1150
  deprecated_attention_block_modules.append(module)
1143
1151
 
@@ -13,14 +13,16 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import numbers
16
17
  from typing import Dict, Optional, Tuple
17
18
 
18
19
  import torch
19
20
  import torch.nn as nn
20
21
  import torch.nn.functional as F
21
22
 
23
+ from ..utils import is_torch_version
22
24
  from .activations import get_activation
23
- from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
25
+ from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
24
26
 
25
27
 
26
28
  class AdaLayerNorm(nn.Module):
@@ -91,7 +93,7 @@ class AdaLayerNormSingle(nn.Module):
91
93
  def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
92
94
  super().__init__()
93
95
 
94
- self.emb = CombinedTimestepSizeEmbeddings(
96
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
95
97
  embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
96
98
  )
97
99
 
@@ -101,8 +103,8 @@ class AdaLayerNormSingle(nn.Module):
101
103
  def forward(
102
104
  self,
103
105
  timestep: torch.Tensor,
104
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
105
- batch_size: int = None,
106
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
107
+ batch_size: Optional[int] = None,
106
108
  hidden_dtype: Optional[torch.dtype] = None,
107
109
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
108
110
  # No modulation happening here.
@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
146
148
  x = F.group_norm(x, self.num_groups, eps=self.eps)
147
149
  x = x * (1 + scale) + shift
148
150
  return x
151
+
152
+
153
class AdaLayerNormContinuous(nn.Module):
    """
    Normalization layer whose output is modulated by a conditioning embedding.

    The conditioning embedding goes through SiLU and a linear projection to `2 * embedding_dim` features, which are
    split in half into a scale and a shift. The input is normalized and then transformed as
    `norm(x) * (1 + scale) + shift`.

    Parameters:
        embedding_dim (`int`): Feature size handed to the inner norm layer; the projection emits twice this size.
        conditioning_embedding_dim (`int`): Input feature size of the conditioning projection.
        elementwise_affine (`bool`, defaults to `True`): Forwarded to the inner norm layer (see the note in
            `__init__`).
        eps (`float`, defaults to `1e-5`): Epsilon forwarded to the inner norm layer.
        bias (`bool`, defaults to `True`): Whether the projection has a bias; also forwarded to `LayerNorm`.
        norm_type (`str`, defaults to `"layer_norm"`): Either `"layer_norm"` or `"rms_norm"`.

    Raises:
        ValueError: If `norm_type` is neither `"layer_norm"` nor `"rms_norm"`.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: Letting the norm layer own scale/shift parameters is somewhat odd, since its output is modulated
        # right away by the scale and shift projected from the conditioning embedding (AdaLayerNorm does not allow
        # it). The original implementation did it this way, so the default is kept; you most likely want to pass
        # `elementwise_affine=False`.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # A single projection yields both modulation terms, hence the doubled output width.
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)

        if norm_type not in ("layer_norm", "rms_norm"):
            raise ValueError(f"unknown norm_type {norm_type}")
        if norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # presumably `x` is (batch, sequence, embedding_dim) and `conditioning_embedding` is
        # (batch, conditioning_embedding_dim) -- TODO confirm against callers
        modulation = self.linear(self.silu(conditioning_embedding))
        scale, shift = modulation.chunk(2, dim=1)
        normalized = self.norm(x)
        # Insert an axis at position 1 so the per-sample terms broadcast over it.
        return normalized * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
183
+
184
+
185
if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Fallback used on older torch: unlike the torch layer norm there, this one takes an optional `bias` flag.
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        """
        Layer normalization with an optional bias, implemented on top of `F.layer_norm`.

        Parameters:
            dim: Normalized shape; an integer is treated as a one-element shape.
            eps (`float`, defaults to `1e-5`): Epsilon passed to `F.layer_norm`.
            elementwise_affine (`bool`, defaults to `True`): Whether to create a learnable `weight`
                (initialized to ones).
            bias (`bool`, defaults to `True`): Whether to create a learnable `bias` (initialized to zeros). Only
                takes effect when `elementwise_affine` is `True`.
        """

        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            normalized_shape = (dim,) if isinstance(dim, numbers.Integral) else dim

            self.eps = eps
            self.dim = torch.Size(normalized_shape)

            # Without the affine transform neither parameter exists, regardless of `bias`.
            self.weight = nn.Parameter(torch.ones(normalized_shape)) if elementwise_affine else None
            self.bias = nn.Parameter(torch.zeros(normalized_shape)) if (elementwise_affine and bias) else None

        def forward(self, input):
            return F.layer_norm(input, self.dim, weight=self.weight, bias=self.bias, eps=self.eps)
210
+
211
+
212
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last axis, with an optional learnable scale.

    Parameters:
        dim: Shape of the learnable weight; an integer is treated as a one-element shape.
        eps (`float`): Value added to the mean of squares before taking the reciprocal square root.
        elementwise_affine (`bool`, defaults to `True`): Whether to create a learnable `weight`
            (initialized to ones).
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        weight_shape = (dim,) if isinstance(dim, numbers.Integral) else dim

        self.eps = eps
        self.dim = torch.Size(weight_shape)
        self.weight = nn.Parameter(torch.ones(weight_shape)) if elementwise_affine else None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype

        # The mean of squares is computed in float32 and always reduced over the last axis only.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)

        if self.weight is None:
            # No scale to apply: hand the result back in the caller's dtype.
            return normalized.to(input_dtype)

        # With a half-precision weight, cast before scaling; with any other weight dtype the result is
        # left in whatever dtype the multiplication produces (it is not cast back to `input_dtype`).
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(self.weight.dtype)
        return normalized * self.weight
242
+
243
+
244
class GlobalResponseNorm(nn.Module):
    """
    Global response normalization with a residual connection.

    For each channel the L2 norm is taken over axes 1 and 2, divided by the mean of those norms across the last
    axis, and used to rescale the input. The learnable `gamma` and `beta` both start at zero, so a freshly
    constructed layer returns its input.

    Parameters:
        dim: Size of the last axis of `gamma` and `beta`, which have shape `(1, 1, 1, dim)`.
    """

    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # presumably `x` is channels-last, i.e. (batch, height, width, dim) -- TODO confirm against callers
        response = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        mean_response = response.mean(dim=-1, keepdim=True)
        # The small constant guards against division by zero.
        relative_response = response / (mean_response + 1e-6)
        return self.gamma * (x * relative_response) + self.beta + x