diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
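Items 8–15 and 234 above show the monolithic `diffusers/loaders.py` (−3336 lines) being split into a `diffusers/loaders/` package. The public mixins are expected to stay importable from `diffusers.loaders`; the sketch below is an assumption based on the new module names in the file list (with `torch` and `transformers` installed), not something stated in this diff:

```python
# Import paths the loaders refactor (items 8-15, 234) is expected to preserve.
# Illustrative, not exhaustive; each name maps to one of the new submodules.
from diffusers.loaders import (
    FromSingleFileMixin,         # single_file.py
    IPAdapterMixin,              # ip_adapter.py
    LoraLoaderMixin,             # lora.py
    TextualInversionLoaderMixin, # textual_inversion.py
)
```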
@@ -0,0 +1,402 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalVAEMixin
+from ...utils import is_torch_version
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
+from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+class TemporalDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 4,
+        out_channels: int = 3,
+        block_out_channels: Tuple[int] = (128, 256, 512, 512),
+        layers_per_block: int = 2,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+        self.mid_block = MidBlockTemporalDecoder(
+            num_layers=self.layers_per_block,
+            in_channels=block_out_channels[-1],
+            out_channels=block_out_channels[-1],
+            attention_head_dim=block_out_channels[-1],
+        )
+
+        # up
+        self.up_blocks = nn.ModuleList([])
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i in range(len(block_out_channels)):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+
+            is_final_block = i == len(block_out_channels) - 1
+            up_block = UpBlockTemporalDecoder(
+                num_layers=self.layers_per_block + 1,
+                in_channels=prev_output_channel,
+                out_channels=output_channel,
+                add_upsample=not is_final_block,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
+
+        self.conv_act = nn.SiLU()
+        self.conv_out = torch.nn.Conv2d(
+            in_channels=block_out_channels[0],
+            out_channels=out_channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+        conv_out_kernel_size = (3, 1, 1)
+        padding = [int(k // 2) for k in conv_out_kernel_size]
+        self.time_conv_out = torch.nn.Conv3d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=conv_out_kernel_size,
+            padding=padding,
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        image_only_indicator: torch.FloatTensor,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
+        r"""The forward method of the `Decoder` class."""
+
+        sample = self.conv_in(sample)
+
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    image_only_indicator,
+                    use_reentrant=False,
+                )
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(up_block),
+                        sample,
+                        image_only_indicator,
+                        use_reentrant=False,
+                    )
+            else:
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    image_only_indicator,
+                )
+                sample = sample.to(upscale_dtype)
+
+                # up
+                for up_block in self.up_blocks:
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(up_block),
+                        sample,
+                        image_only_indicator,
+                    )
+        else:
+            # middle
+            sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
+            sample = sample.to(upscale_dtype)
+
+            # up
+            for up_block in self.up_blocks:
+                sample = up_block(sample, image_only_indicator=image_only_indicator)
+
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        batch_frames, channels, height, width = sample.shape
+        batch_size = batch_frames // num_frames
+        sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+        sample = self.time_conv_out(sample)
+
+        sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
+
+        return sample
+
+
+class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+    r"""
+    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
+        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        force_upcast (`bool`, *optional*, default to `True`):
+            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+            can be fine-tuned / trained to a lower range without loosing too much precision in which case
+            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        latent_channels: int = 4,
+        sample_size: int = 32,
+        scaling_factor: float = 0.18215,
+        force_upcast: float = True,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            double_z=True,
+        )
+
+        # pass init params to Decoder
+        self.decoder = TemporalDecoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+        )
+
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (Encoder, TemporalDecoder)):
+            module.gradient_checkpointing = value
+
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor, _remove_lora=_remove_lora)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
+
+    @apply_forward_hook
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode a batch of images into latents.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            The latent representations of the encoded images. If `return_dict` is True, a
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    @apply_forward_hook
+    def decode(
+        self,
+        z: torch.FloatTensor,
+        num_frames: int,
+        return_dict: bool = True,
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+
+        """
+        batch_size = z.shape[0] // num_frames
+        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
+        decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+        num_frames: int = 1,
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+
+        dec = self.decode(z, num_frames=num_frames).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
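The new-file hunk above is `diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py` (item 24), the VAE used for video decoding. A minimal sketch of how the class is exercised, based only on the signatures shown: frames are stacked along the batch axis, `encode` runs the ordinary 2D encoder per frame, and `decode` takes `num_frames` so the final `Conv3d` can regroup the batch into clips. The tiny one-block config is an assumption chosen so the snippet runs quickly; it is not the Stable Video Diffusion checkpoint configuration.

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder

# Deliberately tiny config for illustration only (real checkpoints use larger block_out_channels).
vae = AutoencoderKLTemporalDecoder(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    block_out_channels=(64,),
    layers_per_block=1,
)

num_frames, height, width = 2, 64, 64
frames = torch.randn(num_frames, 3, height, width)  # frames stacked on the batch axis

posterior = vae.encode(frames).latent_dist           # per-frame 2D encoding
latents = posterior.sample()

# decode() needs num_frames so TemporalDecoder.time_conv_out can reshape (B*T, C, H, W) -> (B, C, T, H, W)
video = vae.decode(latents, num_frames=num_frames).sample
print(video.shape)  # torch.Size([2, 3, 64, 64]) — frames remain stacked on the batch axis
```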
@@ -18,10 +18,10 @@ from typing import Optional, Tuple, Union
 
 import torch
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
-from ..utils.accelerate_utils import apply_forward_hook
-from .modeling_utils import ModelMixin
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DecoderTiny, EncoderTiny
 
 
@@ -91,23 +91,24 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
     """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
-        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
+        in_channels: int = 3,
+        out_channels: int = 3,
+        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
+        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         act_fn: str = "relu",
         latent_channels: int = 4,
         upsampling_scaling_factor: int = 2,
-        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
-        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
+        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
+        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
         latent_magnitude: int = 3,
         latent_shift: float = 0.5,
-        force_upcast: float = False,
+        force_upcast: bool = False,
         scaling_factor: float = 1.0,
     ):
         super().__init__()
@@ -147,33 +148,36 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.tile_sample_min_size = 512
         self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
 
-    def _set_gradient_checkpointing(self, module, value=False):
+        self.register_to_config(block_out_channels=decoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
 
-    def scale_latents(self, x):
+    def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
 
-    def unscale_latents(self, x):
+    def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
 
-    def enable_slicing(self):
+    def enable_slicing(self) -> None:
         r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
         self.use_slicing = True
 
-    def disable_slicing(self):
+    def disable_slicing(self) -> None:
         r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
         """
         self.use_slicing = False
 
-    def enable_tiling(self, use_tiling: bool = True):
+    def enable_tiling(self, use_tiling: bool = True) -> None:
         r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
@@ -181,7 +185,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """
         self.use_tiling = use_tiling
 
-    def disable_tiling(self):
+    def disable_tiling(self) -> None:
         r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
@@ -197,13 +201,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
         Args:
            x (`torch.FloatTensor`): Input batch of images.
-           return_dict (`bool`, *optional*, defaults to `True`):
-               Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
 
         Returns:
-           [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
-               If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
-               plain `tuple` is returned.
+           `torch.FloatTensor`: Encoded batch of images.
         """
         # scale of encoder output relative to input
         sf = self.spatial_scale_factor
@@ -249,13 +249,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
         Args:
            x (`torch.FloatTensor`): Input batch of images.
-           return_dict (`bool`, *optional*, defaults to `True`):
-               Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
 
         Returns:
-           [`~models.vae.DecoderOutput`] or `tuple`:
-               If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
-               returned.
+           `torch.FloatTensor`: Encoded batch of images.
         """
         # scale of decoder output relative to input
         sf = self.spatial_scale_factor
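
The hunks above belong to `diffusers/models/autoencoders/autoencoder_tiny.py` (item 25, moved from `diffusers/models/autoencoder_tiny.py`) and mostly tighten type hints and docstrings. The `scale_latents`/`unscale_latents` pair they annotate maps raw latents from roughly `[-latent_magnitude, latent_magnitude]` onto `[0, 1]` and back; a self-contained check of that arithmetic, using the default `latent_magnitude=3` and `latent_shift=0.5` from the signature:

```python
import torch

# Defaults from AutoencoderTiny.__init__ shown in the diff above.
latent_magnitude, latent_shift = 3, 0.5

raw = torch.tensor([-3.0, 0.0, 3.0])
scaled = raw.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)  # scale_latents: -> [0.0, 0.5, 1.0]
restored = scaled.sub(latent_shift).mul(2 * latent_magnitude)         # unscale_latents: -> [-3.0, 0.0, 3.0]

assert torch.allclose(raw, restored)  # round-trip is lossless inside the clamp range
```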