diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0

diffusers/models/attention_flax.py

@@ -110,7 +110,10 @@ def jax_memory_efficient_attention(
     )
 
     _, res = jax.lax.scan(
-        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
+        f=chunk_scanner,
+        init=0,
+        xs=None,
+        length=math.ceil(num_q / query_chunk_size),  # start counter  # stop counter
     )
 
     return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
@@ -138,6 +141,7 @@ class FlaxAttention(nn.Module):
             Parameters `dtype`
 
     """
+
     query_dim: int
     heads: int = 8
     dim_head: int = 64
@@ -262,6 +266,7 @@ class FlaxBasicTransformerBlock(nn.Module):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """
+
     dim: int
     n_heads: int
     d_head: int
@@ -347,6 +352,7 @@ class FlaxTransformer2DModel(nn.Module):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
     """
+
     in_channels: int
     n_heads: int
     d_head: int
@@ -442,6 +448,7 @@ class FlaxFeedForward(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
@@ -471,6 +478,7 @@ class FlaxGEGLU(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32

diffusers/models/attention_processor.py

@@ -109,15 +109,19 @@ class Attention(nn.Module):
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
     ):
         super().__init__()
-        self.inner_dim = dim_head * heads
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +130,7 @@ class Attention(nn.Module):
         self.scale_qk = scale_qk
         self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
-        self.heads = heads
+        self.heads = out_dim // dim_head if out_dim is not None else heads
         # for slice_size > 0 the attention score computation
         # is split across the batch axis to save memory
         # You can set slice_size with `set_attention_slice`
@@ -178,6 +182,7 @@
         else:
             linear_cls = LoRACompatibleLinear
 
+        self.linear_cls = linear_cls
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
@@ -193,7 +198,7 @@
             self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
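The new `out_dim` argument decouples the width of the attention output from `query_dim`: when it is set, `heads` is re-derived as `out_dim // dim_head` and `to_out` projects to `out_dim` instead of `query_dim`. A minimal sketch of the resulting behaviour (the dimensions below are illustrative, not taken from any diffusers checkpoint):

    import torch
    from diffusers.models.attention_processor import Attention

    # out_dim overrides both inner_dim (= dim_head * heads by default) and the to_out width.
    attn = Attention(query_dim=320, dim_head=40, heads=8, out_dim=1280)
    hidden_states = torch.randn(1, 77, 320)

    out = attn(hidden_states)
    print(out.shape)   # torch.Size([1, 77, 1280]) -- follows out_dim, not query_dim
    print(attn.heads)  # 32 == out_dim // dim_head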
@@ -690,6 +695,32 @@
 
         return encoder_hidden_states
 
+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
 
 class AttnProcessor:
     r"""
@@ -1182,9 +1213,6 @@ class AttnProcessor2_0:
         scale: float = 1.0,
     ) -> torch.FloatTensor:
         residual = hidden_states
-
-        args = () if USE_PEFT_BACKEND else (scale,)
-
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -1251,6 +1279,103 @@
         return hidden_states
 
 
+class FusedAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
+    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+    <Tip warning={true}>
+
+    This API is currently 🧪 experimental in nature and can change in future.
+
+    </Tip>
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states, *args)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states, *args)
+
+            kv = attn.to_kv(encoder_hidden_states, *args)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
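Together with the `fuse_qkv_projections` methods referenced by the "Copied from" markers elsewhere in this release, the fused processor is opted into explicitly. A hedged sketch of the intended flow (the model id is the usual SD 1.5 checkpoint; the API is marked experimental in the docstring above, so details may change):

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import FusedAttnProcessor2_0

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Fuse q/k/v (and k/v for cross-attention) weights, then route attention
    # through the processor that expects the fused to_qkv / to_kv layers.
    pipe.unet.fuse_qkv_projections()
    pipe.unet.set_attn_processor(FusedAttnProcessor2_0())

    image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]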
@@ -1975,6 +2100,250 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         return attn.processor(attn, hidden_states, *args, **kwargs)
 
 
+class IPAdapterAttnProcessor(nn.Module):
+    r"""
+    Attention processor for IP-Adapater.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, defaults to 4):
+            The context length of the image features.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+        self.scale = scale
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        scale=1.0,
+    ):
+        if scale != 1.0:
+            logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        # split hidden states
+        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+        encoder_hidden_states, ip_hidden_states = (
+            encoder_hidden_states[:, :end_pos, :],
+            encoder_hidden_states[:, end_pos:, :],
+        )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = attn.head_to_batch_dim(ip_key)
+        ip_value = attn.head_to_batch_dim(ip_value)
+
+        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapater for PyTorch 2.0.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, defaults to 4):
+            The context length of the image features.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+        self.scale = scale
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        scale=1.0,
+    ):
+        if scale != 1.0:
+            logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        # split hidden states
+        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+        encoder_hidden_states, ip_hidden_states = (
+            encoder_hidden_states[:, :end_pos, :],
+            encoder_hidden_states[:, end_pos:, :],
+        )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        ip_hidden_states = F.scaled_dot_product_attention(
+            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
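These processors are what the new `IPAdapterMixin.load_ip_adapter` loader (see `diffusers/loaders/ip_adapter.py` in the file list) installs on the UNet, and their `scale` attribute is the one the warning above says to set through `set_ip_adapter_scale`. A sketch of typical usage; the repository, subfolder, and weight names follow the public IP-Adapter release rather than anything in this diff:

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.utils import load_image

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Loads the image encoder and replaces the UNet's cross-attention processors
    # with IPAdapterAttnProcessor / IPAdapterAttnProcessor2_0.
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
    pipe.set_ip_adapter_scale(0.6)  # weight of the image prompt relative to the text prompt

    reference = load_image("path/to/any_reference_image.png")  # hypothetical local path
    image = pipe("best quality, a corgi on the beach", ip_adapter_image=reference).images[0]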
@@ -1998,11 +2367,14 @@ CROSS_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
 )
 
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,

diffusers/models/autoencoders/__init__.py (new file)

@@ -0,0 +1,5 @@
+from .autoencoder_asym_kl import AsymmetricAutoencoderKL
+from .autoencoder_kl import AutoencoderKL
+from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+from .autoencoder_tiny import AutoencoderTiny
+from .consistency_decoder_vae import ConsistencyDecoderVAE
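The VAE implementations now live in a dedicated `diffusers.models.autoencoders` subpackage, while the public re-exports stay in place, so most user code is only affected if it imported the private modules directly. A short sketch of the equivalent imports (the deep paths are internal and may move again):

    # Public entry points, unchanged by the refactor:
    from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVAE

    # Added in this release range (used by Stable Video Diffusion):
    from diffusers import AutoencoderKLTemporalDecoder

    # Internal module path, moved by this release:
    # 0.23.1: from diffusers.models.autoencoder_kl import AutoencoderKL
    from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL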

diffusers/models/autoencoders/autoencoder_asym_kl.py (renamed from diffusers/models/autoencoder_asym_kl.py)

@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils.accelerate_utils import apply_forward_hook
-from .autoencoder_kl import AutoencoderKLOutput
-from .modeling_utils import ModelMixin
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
 
 
@@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
-        down_block_out_channels: Tuple[int] = (64,),
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        down_block_out_channels: Tuple[int, ...] = (64,),
         layers_per_down_block: int = 1,
-        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
-        up_block_out_channels: Tuple[int] = (64,),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        up_block_out_channels: Tuple[int, ...] = (64,),
         layers_per_up_block: int = 1,
         act_fn: str = "silu",
         latent_channels: int = 4,
@@ -108,8 +108,13 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        self.register_to_config(block_out_channels=up_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     @apply_forward_hook
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -125,7 +130,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         image: Optional[torch.FloatTensor] = None,
         mask: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z, image, mask)
 
@@ -142,7 +147,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         image: Optional[torch.FloatTensor] = None,
         mask: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         decoded = self._decode(z, image, mask).sample
 
         if not return_dict:
@@ -157,7 +162,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         r"""
         Args:
             sample (`torch.FloatTensor`): Input sample.

diffusers/models/autoencoders/autoencoder_kl.py (renamed from diffusers/models/autoencoder_kl.py)

@@ -11,41 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalVAEMixin
-from ..utils import BaseOutput
-from ..utils.accelerate_utils import apply_forward_hook
-from .attention_processor import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalVAEMixin
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from .modeling_utils import ModelMixin
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-@dataclass
-class AutoencoderKLOutput(BaseOutput):
-    """
-    Output of AutoencoderKL encoding method.
-
-    Args:
-        latent_dist (`DiagonalGaussianDistribution`):
-            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
-            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
-    """
-
-    latent_dist: "DiagonalGaussianDistribution"
-
-
 class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
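`AutoencoderKLOutput` is no longer defined alongside `AutoencoderKL`; it now comes from the new `diffusers/models/modeling_outputs.py` listed above. Code that imported it from the old module needs roughly this change (a sketch of the before/after paths):

    # 0.23.1
    # from diffusers.models.autoencoder_kl import AutoencoderKLOutput

    # 0.25.0
    from diffusers.models.modeling_outputs import AutoencoderKLOutput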
@@ -322,13 +308,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return DecoderOutput(sample=decoded)
 
-    def blend_v(self, a, b, blend_extent):
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[2], b.shape[2], blend_extent)
         for y in range(blend_extent):
             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
         return b
 
-    def blend_h(self, a, b, blend_extent):
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[3], b.shape[3], blend_extent)
         for x in range(blend_extent):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
@@ -463,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             return (dec,)
 
         return DecoderOutput(sample=dec)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
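A hedged sketch of the new VAE-level API: fuse the projections, optionally switch to the fused processor, and restore the originals with `unfuse_qkv_projections`. The checkpoint id and latent shape are illustrative; the diff itself makes no performance claims:

    import torch
    from diffusers import AutoencoderKL
    from diffusers.models.attention_processor import FusedAttnProcessor2_0

    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")

    vae.fuse_qkv_projections()                       # concatenates q/k/v weights in the attention blocks
    vae.set_attn_processor(FusedAttnProcessor2_0())  # processor that expects the fused to_qkv layer

    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        image = vae.decode(latents / vae.config.scaling_factor).sample

    vae.unfuse_qkv_projections()  # puts the original attention processors back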