diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

Files changed (176)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +0 -1
  4. diffusers/dependency_versions_table.py +4 -5
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +8 -7
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
  171. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
  172. diffusers/loaders.py +0 -3336
  173. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
  175. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
  176. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
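The headline additions visible in the list above are the Stable Video Diffusion stack (the stable_video_diffusion pipeline, UNetSpatioTemporalConditionModel, and the temporal VAE decoder), the Kandinsky 3 pipelines, and the split of the monolithic diffusers/loaders.py into a loaders/ package. For rough orientation, image-to-video generation with the new pipeline looks something like the sketch below; the checkpoint id and the exact call arguments are assumptions based on the upstream release, not something this diff itself confirms.

import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# Image-to-video pipeline introduced in this release (checkpoint id assumed).
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Condition the generated video clip on a single still image.
image = load_image("input.png").resize((1024, 576))
frames = pipe(image, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)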
diffusers/models/unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,489 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import BaseOutput, logging
+from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetSpatioTemporalConditionOutput(BaseOutput):
+    """
+    The output of [`UNetSpatioTemporalConditionModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor = None
+
+
+class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+    r"""
+    A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep, and
+    returns a sample-shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+            The tuple of upsample blocks to use.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        addition_time_embed_dim: (`int`, defaults to 256):
+            Dimension to encode the additional time ids.
+        projection_class_embeddings_input_dim (`int`, defaults to 768):
+            The dimension of the projection of encoded `added_time_ids`.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+            The number of attention heads.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 8,
+        out_channels: int = 4,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlockSpatioTemporal",
+            "CrossAttnDownBlockSpatioTemporal",
+            "CrossAttnDownBlockSpatioTemporal",
+            "DownBlockSpatioTemporal",
+        ),
+        up_block_types: Tuple[str] = (
+            "UpBlockSpatioTemporal",
+            "CrossAttnUpBlockSpatioTemporal",
+            "CrossAttnUpBlockSpatioTemporal",
+            "CrossAttnUpBlockSpatioTemporal",
+        ),
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        addition_time_embed_dim: int = 256,
+        projection_class_embeddings_input_dim: int = 768,
+        layers_per_block: Union[int, Tuple[int]] = 2,
+        cross_attention_dim: Union[int, Tuple[int]] = 1024,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+        num_frames: int = 25,
+    ):
+        super().__init__()
+
+        self.sample_size = sample_size
+
+        # Check inputs
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        # input
+        self.conv_in = nn.Conv2d(
+            in_channels,
+            block_out_channels[0],
+            kernel_size=3,
+            padding=1,
+        )
+
+        # time
+        time_embed_dim = block_out_channels[0] * 4
+
+        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+        timestep_input_dim = block_out_channels[0]
+
+        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        blocks_time_embed_dim = time_embed_dim
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=1e-5,
+                cross_attention_dim=cross_attention_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                resnet_act_fn="silu",
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlockSpatioTemporal(
+            block_out_channels[-1],
+            temb_channels=blocks_time_embed_dim,
+            transformer_layers_per_block=transformer_layers_per_block[-1],
+            cross_attention_dim=cross_attention_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
+        )
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=1e-5,
+                resolution_idx=i,
+                cross_attention_dim=reversed_cross_attention_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
+                resnet_act_fn="silu",
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+        self.conv_act = nn.SiLU()
+
+        self.conv_out = nn.Conv2d(
+            block_out_channels[0],
+            out_channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        added_time_ids: torch.Tensor,
+        return_dict: bool = True,
+    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+        r"""
+        The [`UNetSpatioTemporalConditionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+            added_time_ids: (`torch.FloatTensor`):
+                The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+                embeddings and added to the time embeddings.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
+                tuple.
+        Returns:
+            [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        batch_size, num_frames = sample.shape[:2]
+        timesteps = timesteps.expand(batch_size)
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+
+        emb = self.time_embedding(t_emb)
+
+        time_embeds = self.add_time_proj(added_time_ids.flatten())
+        time_embeds = time_embeds.reshape((batch_size, -1))
+        time_embeds = time_embeds.to(emb.dtype)
+        aug_emb = self.add_embedding(time_embeds)
+        emb = emb + aug_emb
+
+        # Flatten the batch and frames dimensions
+        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+        sample = sample.flatten(0, 1)
+        # Repeat the embeddings num_video_frames times
+        # emb: [batch, channels] -> [batch * frames, channels]
+        emb = emb.repeat_interleave(num_frames, dim=0)
+        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                )
+            else:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    image_only_indicator=image_only_indicator,
+                )
+
+            down_block_res_samples += res_samples
+
+        # 4. mid
+        sample = self.mid_block(
+            hidden_states=sample,
+            temb=emb,
+            encoder_hidden_states=encoder_hidden_states,
+            image_only_indicator=image_only_indicator,
+        )
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    image_only_indicator=image_only_indicator,
+                )
+
+        # 6. post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        # 7. Reshape back to original shape
+        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+        if not return_dict:
+            return (sample,)

+        return UNetSpatioTemporalConditionOutput(sample=sample)
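For orientation, the forward pass above flattens the frame axis into the batch axis, repeats the time and conditioning embeddings per frame, and reshapes back at the end. The sketch below drives it with dummy tensors and a deliberately tiny, hypothetical configuration (the channel, head, and cross-attention sizes are illustrative only); the three `added_time_ids` per sample follow from `projection_class_embeddings_input_dim=768` being 3 × `addition_time_embed_dim=256`.

import torch

from diffusers.models import UNetSpatioTemporalConditionModel

# Toy sizes chosen only to keep the forward pass cheap; released SVD checkpoints
# use the defaults shown above (320/640/1280/1280 channels, 25 frames, ...).
unet = UNetSpatioTemporalConditionModel(
    sample_size=32,
    block_out_channels=(64, 64, 64, 64),
    num_attention_heads=(2, 2, 2, 2),
    cross_attention_dim=32,
    layers_per_block=1,
)

batch, frames = 1, 4
latents = torch.randn(batch, frames, 8, 32, 32)    # (B, F, C_in=8, H, W)
image_embeddings = torch.randn(batch, 1, 32)       # repeated per frame inside forward()
added_time_ids = torch.zeros(batch, 3)             # 768 / 256 = 3 additional ids per sample

out = unet(latents, timestep=10, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids)
print(out.sample.shape)                            # torch.Size([1, 4, 4, 32, 32])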
diffusers/models/vae.py CHANGED
@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
 from ..utils.torch_utils import randn_tensor
 from .activations import get_activation
 from .attention_processor import SpatialNorm
-from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import (
+    AutoencoderTinyBlock,
+    UNetMidBlock2D,
+    get_down_block,
+    get_up_block,
+)
 
 
 @dataclass
@@ -274,7 +279,9 @@ class Decoder(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
+        self,
+        sample: torch.FloatTensor,
+        latent_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
 
@@ -292,14 +299,20 @@ class Decoder(nn.Module):
             if is_torch_version(">=", "1.11.0"):
                 # middle
                 sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    latent_embeds,
+                    use_reentrant=False,
                 )
                 sample = sample.to(upscale_dtype)
 
                 # up
                 for up_block in self.up_blocks:
                     sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+                        create_custom_forward(up_block),
+                        sample,
+                        latent_embeds,
+                        use_reentrant=False,
                     )
             else:
                 # middle
@@ -540,7 +553,10 @@ class MaskConditionDecoder(nn.Module):
             if is_torch_version(">=", "1.11.0"):
                 # middle
                 sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    latent_embeds,
+                    use_reentrant=False,
                 )
                 sample = sample.to(upscale_dtype)
 
@@ -548,7 +564,10 @@ class MaskConditionDecoder(nn.Module):
                 if image is not None and mask is not None:
                     masked_image = (1 - mask) * image
                     im_x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
+                        create_custom_forward(self.condition_encoder),
+                        masked_image,
+                        mask,
+                        use_reentrant=False,
                     )
 
                 # up
@@ -558,7 +577,10 @@ class MaskConditionDecoder(nn.Module):
                         mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                         sample = sample * mask_ + sample_ * (1 - mask_)
                     sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+                        create_custom_forward(up_block),
+                        sample,
+                        latent_embeds,
+                        use_reentrant=False,
                     )
                     if image is not None and mask is not None:
                         sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,7 +595,9 @@ class MaskConditionDecoder(nn.Module):
                 if image is not None and mask is not None:
                     masked_image = (1 - mask) * image
                     im_x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.condition_encoder), masked_image, mask
+                        create_custom_forward(self.condition_encoder),
+                        masked_image,
+                        mask,
                     )
 
                 # up
@@ -754,7 +778,10 @@ class DiagonalGaussianDistribution(object):
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         # make sure sample is on the same device as the parameters and has same dtype
         sample = randn_tensor(
-            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
         )
         x = self.mean + self.std * sample
         return x
@@ -764,7 +791,10 @@ class DiagonalGaussianDistribution(object):
             return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3],
+                )
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
@@ -779,7 +809,10 @@ class DiagonalGaussianDistribution(object):
         if self.deterministic:
             return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
 
     def mode(self) -> torch.Tensor:
         return self.mean
@@ -820,7 +853,16 @@ class EncoderTiny(nn.Module):
             if i == 0:
                 layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
             else:
-                layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
+                layers.append(
+                    nn.Conv2d(
+                        num_channels,
+                        num_channels,
+                        kernel_size=3,
+                        padding=1,
+                        stride=2,
+                        bias=False,
+                    )
+                )
 
             for _ in range(num_block):
                 layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
@@ -899,7 +941,15 @@ class DecoderTiny(nn.Module):
                 layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
 
             conv_out_channel = num_channels if not is_final_block else out_channels
-            layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
+            layers.append(
+                nn.Conv2d(
+                    num_channels,
+                    conv_out_channel,
+                    kernel_size=3,
+                    padding=1,
+                    bias=is_final_block,
+                )
+            )
 
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False
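The DiagonalGaussianDistribution hunks above only reflow the reparameterized sample, KL, and NLL expressions; behavior is unchanged. A minimal sketch of how the class is typically used (its constructor is not part of this diff; the usual diffusers convention of concatenating mean and log-variance along the channel axis is assumed):

import torch

from diffusers.models.vae import DiagonalGaussianDistribution

# Encoder output: mean and logvar concatenated along the channel axis (assumed convention).
parameters = torch.randn(2, 8, 32, 32)             # -> 4 latent channels of mean + 4 of logvar
posterior = DiagonalGaussianDistribution(parameters)

z = posterior.sample(generator=torch.Generator().manual_seed(0))  # x = mean + std * eps
kl = posterior.kl()                                # 0.5 * sum(mean^2 + var - 1 - logvar) over C, H, W
print(z.shape, kl.shape)                           # torch.Size([2, 4, 32, 32]) torch.Size([2])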