diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_allegro.py (new file)
@@ -0,0 +1,1149 @@
1
+ # Copyright 2024 The RhymesAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils.accelerate_utils import apply_forward_hook
24
+ from ..attention_processor import Attention, SpatialNorm
25
+ from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
26
+ from ..downsampling import Downsample2D
27
+ from ..modeling_outputs import AutoencoderKLOutput
28
+ from ..modeling_utils import ModelMixin
29
+ from ..resnet import ResnetBlock2D
30
+ from ..upsampling import Upsample2D
31
+
32
+
33
+ class AllegroTemporalConvLayer(nn.Module):
34
+ r"""
35
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
36
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ in_dim: int,
42
+ out_dim: Optional[int] = None,
43
+ dropout: float = 0.0,
44
+ norm_num_groups: int = 32,
45
+ up_sample: bool = False,
46
+ down_sample: bool = False,
47
+ stride: int = 1,
48
+ ) -> None:
49
+ super().__init__()
50
+
51
+ out_dim = out_dim or in_dim
52
+ pad_h = pad_w = int((stride - 1) * 0.5)
53
+ pad_t = 0
54
+
55
+ self.down_sample = down_sample
56
+ self.up_sample = up_sample
57
+
58
+ if down_sample:
59
+ self.conv1 = nn.Sequential(
60
+ nn.GroupNorm(norm_num_groups, in_dim),
61
+ nn.SiLU(),
62
+ nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
63
+ )
64
+ elif up_sample:
65
+ self.conv1 = nn.Sequential(
66
+ nn.GroupNorm(norm_num_groups, in_dim),
67
+ nn.SiLU(),
68
+ nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
69
+ )
70
+ else:
71
+ self.conv1 = nn.Sequential(
72
+ nn.GroupNorm(norm_num_groups, in_dim),
73
+ nn.SiLU(),
74
+ nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
75
+ )
76
+ self.conv2 = nn.Sequential(
77
+ nn.GroupNorm(norm_num_groups, out_dim),
78
+ nn.SiLU(),
79
+ nn.Dropout(dropout),
80
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
81
+ )
82
+ self.conv3 = nn.Sequential(
83
+ nn.GroupNorm(norm_num_groups, out_dim),
84
+ nn.SiLU(),
85
+ nn.Dropout(dropout),
86
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
87
+ )
88
+ self.conv4 = nn.Sequential(
89
+ nn.GroupNorm(norm_num_groups, out_dim),
90
+ nn.SiLU(),
91
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
92
+ )
93
+
94
+ @staticmethod
95
+ def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
96
+ hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
97
+ hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
98
+ return hidden_states
99
+
100
+ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
101
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
102
+
103
+ if self.down_sample:
104
+ identity = hidden_states[:, :, ::2]
105
+ elif self.up_sample:
106
+ identity = hidden_states.repeat_interleave(2, dim=2)
107
+ else:
108
+ identity = hidden_states
109
+
110
+ if self.down_sample or self.up_sample:
111
+ hidden_states = self.conv1(hidden_states)
112
+ else:
113
+ hidden_states = self._pad_temporal_dim(hidden_states)
114
+ hidden_states = self.conv1(hidden_states)
115
+
116
+ if self.up_sample:
117
+ hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3)
118
+
119
+ hidden_states = self._pad_temporal_dim(hidden_states)
120
+ hidden_states = self.conv2(hidden_states)
121
+
122
+ hidden_states = self._pad_temporal_dim(hidden_states)
123
+ hidden_states = self.conv3(hidden_states)
124
+
125
+ hidden_states = self._pad_temporal_dim(hidden_states)
126
+ hidden_states = self.conv4(hidden_states)
127
+
128
+ hidden_states = identity + hidden_states
129
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
130
+
131
+ return hidden_states
132
+
133
+
134
+ class AllegroDownBlock3D(nn.Module):
135
+ def __init__(
136
+ self,
137
+ in_channels: int,
138
+ out_channels: int,
139
+ dropout: float = 0.0,
140
+ num_layers: int = 1,
141
+ resnet_eps: float = 1e-6,
142
+ resnet_time_scale_shift: str = "default",
143
+ resnet_act_fn: str = "swish",
144
+ resnet_groups: int = 32,
145
+ resnet_pre_norm: bool = True,
146
+ output_scale_factor: float = 1.0,
147
+ spatial_downsample: bool = True,
148
+ temporal_downsample: bool = False,
149
+ downsample_padding: int = 1,
150
+ ):
151
+ super().__init__()
152
+
153
+ resnets = []
154
+ temp_convs = []
155
+
156
+ for i in range(num_layers):
157
+ in_channels = in_channels if i == 0 else out_channels
158
+ resnets.append(
159
+ ResnetBlock2D(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ temb_channels=None,
163
+ eps=resnet_eps,
164
+ groups=resnet_groups,
165
+ dropout=dropout,
166
+ time_embedding_norm=resnet_time_scale_shift,
167
+ non_linearity=resnet_act_fn,
168
+ output_scale_factor=output_scale_factor,
169
+ pre_norm=resnet_pre_norm,
170
+ )
171
+ )
172
+ temp_convs.append(
173
+ AllegroTemporalConvLayer(
174
+ out_channels,
175
+ out_channels,
176
+ dropout=0.1,
177
+ norm_num_groups=resnet_groups,
178
+ )
179
+ )
180
+
181
+ self.resnets = nn.ModuleList(resnets)
182
+ self.temp_convs = nn.ModuleList(temp_convs)
183
+
184
+ if temporal_downsample:
185
+ self.temp_convs_down = AllegroTemporalConvLayer(
186
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
187
+ )
188
+ self.add_temp_downsample = temporal_downsample
189
+
190
+ if spatial_downsample:
191
+ self.downsamplers = nn.ModuleList(
192
+ [
193
+ Downsample2D(
194
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
195
+ )
196
+ ]
197
+ )
198
+ else:
199
+ self.downsamplers = None
200
+
201
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
202
+ batch_size = hidden_states.shape[0]
203
+
204
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
205
+
206
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
207
+ hidden_states = resnet(hidden_states, temb=None)
208
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
209
+
210
+ if self.add_temp_downsample:
211
+ hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size)
212
+
213
+ if self.downsamplers is not None:
214
+ for downsampler in self.downsamplers:
215
+ hidden_states = downsampler(hidden_states)
216
+
217
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
218
+ return hidden_states
219
+
220
+
221
+ class AllegroUpBlock3D(nn.Module):
222
+ def __init__(
223
+ self,
224
+ in_channels: int,
225
+ out_channels: int,
226
+ dropout: float = 0.0,
227
+ num_layers: int = 1,
228
+ resnet_eps: float = 1e-6,
229
+ resnet_time_scale_shift: str = "default", # default, spatial
230
+ resnet_act_fn: str = "swish",
231
+ resnet_groups: int = 32,
232
+ resnet_pre_norm: bool = True,
233
+ output_scale_factor: float = 1.0,
234
+ spatial_upsample: bool = True,
235
+ temporal_upsample: bool = False,
236
+ temb_channels: Optional[int] = None,
237
+ ):
238
+ super().__init__()
239
+
240
+ resnets = []
241
+ temp_convs = []
242
+
243
+ for i in range(num_layers):
244
+ input_channels = in_channels if i == 0 else out_channels
245
+
246
+ resnets.append(
247
+ ResnetBlock2D(
248
+ in_channels=input_channels,
249
+ out_channels=out_channels,
250
+ temb_channels=temb_channels,
251
+ eps=resnet_eps,
252
+ groups=resnet_groups,
253
+ dropout=dropout,
254
+ time_embedding_norm=resnet_time_scale_shift,
255
+ non_linearity=resnet_act_fn,
256
+ output_scale_factor=output_scale_factor,
257
+ pre_norm=resnet_pre_norm,
258
+ )
259
+ )
260
+ temp_convs.append(
261
+ AllegroTemporalConvLayer(
262
+ out_channels,
263
+ out_channels,
264
+ dropout=0.1,
265
+ norm_num_groups=resnet_groups,
266
+ )
267
+ )
268
+
269
+ self.resnets = nn.ModuleList(resnets)
270
+ self.temp_convs = nn.ModuleList(temp_convs)
271
+
272
+ self.add_temp_upsample = temporal_upsample
273
+ if temporal_upsample:
274
+ self.temp_conv_up = AllegroTemporalConvLayer(
275
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
276
+ )
277
+
278
+ if spatial_upsample:
279
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
280
+ else:
281
+ self.upsamplers = None
282
+
283
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
284
+ batch_size = hidden_states.shape[0]
285
+
286
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
287
+
288
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
289
+ hidden_states = resnet(hidden_states, temb=None)
290
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
291
+
292
+ if self.add_temp_upsample:
293
+ hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size)
294
+
295
+ if self.upsamplers is not None:
296
+ for upsampler in self.upsamplers:
297
+ hidden_states = upsampler(hidden_states)
298
+
299
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
300
+ return hidden_states
301
+
302
+
303
+ class AllegroMidBlock3DConv(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channels: int,
307
+ temb_channels: int,
308
+ dropout: float = 0.0,
309
+ num_layers: int = 1,
310
+ resnet_eps: float = 1e-6,
311
+ resnet_time_scale_shift: str = "default", # default, spatial
312
+ resnet_act_fn: str = "swish",
313
+ resnet_groups: int = 32,
314
+ resnet_pre_norm: bool = True,
315
+ add_attention: bool = True,
316
+ attention_head_dim: int = 1,
317
+ output_scale_factor: float = 1.0,
318
+ ):
319
+ super().__init__()
320
+
321
+ # there is always at least one resnet
322
+ resnets = [
323
+ ResnetBlock2D(
324
+ in_channels=in_channels,
325
+ out_channels=in_channels,
326
+ temb_channels=temb_channels,
327
+ eps=resnet_eps,
328
+ groups=resnet_groups,
329
+ dropout=dropout,
330
+ time_embedding_norm=resnet_time_scale_shift,
331
+ non_linearity=resnet_act_fn,
332
+ output_scale_factor=output_scale_factor,
333
+ pre_norm=resnet_pre_norm,
334
+ )
335
+ ]
336
+ temp_convs = [
337
+ AllegroTemporalConvLayer(
338
+ in_channels,
339
+ in_channels,
340
+ dropout=0.1,
341
+ norm_num_groups=resnet_groups,
342
+ )
343
+ ]
344
+ attentions = []
345
+
346
+ if attention_head_dim is None:
347
+ attention_head_dim = in_channels
348
+
349
+ for _ in range(num_layers):
350
+ if add_attention:
351
+ attentions.append(
352
+ Attention(
353
+ in_channels,
354
+ heads=in_channels // attention_head_dim,
355
+ dim_head=attention_head_dim,
356
+ rescale_output_factor=output_scale_factor,
357
+ eps=resnet_eps,
358
+ norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
359
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
360
+ residual_connection=True,
361
+ bias=True,
362
+ upcast_softmax=True,
363
+ _from_deprecated_attn_block=True,
364
+ )
365
+ )
366
+ else:
367
+ attentions.append(None)
368
+
369
+ resnets.append(
370
+ ResnetBlock2D(
371
+ in_channels=in_channels,
372
+ out_channels=in_channels,
373
+ temb_channels=temb_channels,
374
+ eps=resnet_eps,
375
+ groups=resnet_groups,
376
+ dropout=dropout,
377
+ time_embedding_norm=resnet_time_scale_shift,
378
+ non_linearity=resnet_act_fn,
379
+ output_scale_factor=output_scale_factor,
380
+ pre_norm=resnet_pre_norm,
381
+ )
382
+ )
383
+
384
+ temp_convs.append(
385
+ AllegroTemporalConvLayer(
386
+ in_channels,
387
+ in_channels,
388
+ dropout=0.1,
389
+ norm_num_groups=resnet_groups,
390
+ )
391
+ )
392
+
393
+ self.resnets = nn.ModuleList(resnets)
394
+ self.temp_convs = nn.ModuleList(temp_convs)
395
+ self.attentions = nn.ModuleList(attentions)
396
+
397
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
398
+ batch_size = hidden_states.shape[0]
399
+
400
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
401
+ hidden_states = self.resnets[0](hidden_states, temb=None)
402
+
403
+ hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)
404
+
405
+ for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
406
+ hidden_states = attn(hidden_states)
407
+ hidden_states = resnet(hidden_states, temb=None)
408
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
409
+
410
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
411
+ return hidden_states
412
+
413
+
414
+ class AllegroEncoder3D(nn.Module):
415
+ def __init__(
416
+ self,
417
+ in_channels: int = 3,
418
+ out_channels: int = 3,
419
+ down_block_types: Tuple[str, ...] = (
420
+ "AllegroDownBlock3D",
421
+ "AllegroDownBlock3D",
422
+ "AllegroDownBlock3D",
423
+ "AllegroDownBlock3D",
424
+ ),
425
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
426
+ temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
427
+ layers_per_block: int = 2,
428
+ norm_num_groups: int = 32,
429
+ act_fn: str = "silu",
430
+ double_z: bool = True,
431
+ ):
432
+ super().__init__()
433
+
434
+ self.conv_in = nn.Conv2d(
435
+ in_channels,
436
+ block_out_channels[0],
437
+ kernel_size=3,
438
+ stride=1,
439
+ padding=1,
440
+ )
441
+
442
+ self.temp_conv_in = nn.Conv3d(
443
+ in_channels=block_out_channels[0],
444
+ out_channels=block_out_channels[0],
445
+ kernel_size=(3, 1, 1),
446
+ padding=(1, 0, 0),
447
+ )
448
+
449
+ self.down_blocks = nn.ModuleList([])
450
+
451
+ # down
452
+ output_channel = block_out_channels[0]
453
+ for i, down_block_type in enumerate(down_block_types):
454
+ input_channel = output_channel
455
+ output_channel = block_out_channels[i]
456
+ is_final_block = i == len(block_out_channels) - 1
457
+
458
+ if down_block_type == "AllegroDownBlock3D":
459
+ down_block = AllegroDownBlock3D(
460
+ num_layers=layers_per_block,
461
+ in_channels=input_channel,
462
+ out_channels=output_channel,
463
+ spatial_downsample=not is_final_block,
464
+ temporal_downsample=temporal_downsample_blocks[i],
465
+ resnet_eps=1e-6,
466
+ downsample_padding=0,
467
+ resnet_act_fn=act_fn,
468
+ resnet_groups=norm_num_groups,
469
+ )
470
+ else:
471
+ raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`")
472
+
473
+ self.down_blocks.append(down_block)
474
+
475
+ # mid
476
+ self.mid_block = AllegroMidBlock3DConv(
477
+ in_channels=block_out_channels[-1],
478
+ resnet_eps=1e-6,
479
+ resnet_act_fn=act_fn,
480
+ output_scale_factor=1,
481
+ resnet_time_scale_shift="default",
482
+ attention_head_dim=block_out_channels[-1],
483
+ resnet_groups=norm_num_groups,
484
+ temb_channels=None,
485
+ )
486
+
487
+ # out
488
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
489
+ self.conv_act = nn.SiLU()
490
+
491
+ conv_out_channels = 2 * out_channels if double_z else out_channels
492
+
493
+ self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
494
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
495
+
496
+ self.gradient_checkpointing = False
497
+
498
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
499
+ batch_size = sample.shape[0]
500
+
501
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
502
+ sample = self.conv_in(sample)
503
+
504
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
505
+ residual = sample
506
+ sample = self.temp_conv_in(sample)
507
+ sample = sample + residual
508
+
509
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
510
+
511
+ def create_custom_forward(module):
512
+ def custom_forward(*inputs):
513
+ return module(*inputs)
514
+
515
+ return custom_forward
516
+
517
+ # Down blocks
518
+ for down_block in self.down_blocks:
519
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
520
+
521
+ # Mid block
522
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
523
+ else:
524
+ # Down blocks
525
+ for down_block in self.down_blocks:
526
+ sample = down_block(sample)
527
+
528
+ # Mid block
529
+ sample = self.mid_block(sample)
530
+
531
+ # Post process
532
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
533
+ sample = self.conv_norm_out(sample)
534
+ sample = self.conv_act(sample)
535
+
536
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
537
+ residual = sample
538
+ sample = self.temp_conv_out(sample)
539
+ sample = sample + residual
540
+
541
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
542
+ sample = self.conv_out(sample)
543
+
544
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
545
+ return sample
546
+
547
+
548
+ class AllegroDecoder3D(nn.Module):
549
+ def __init__(
550
+ self,
551
+ in_channels: int = 4,
552
+ out_channels: int = 3,
553
+ up_block_types: Tuple[str, ...] = (
554
+ "AllegroUpBlock3D",
555
+ "AllegroUpBlock3D",
556
+ "AllegroUpBlock3D",
557
+ "AllegroUpBlock3D",
558
+ ),
559
+ temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
560
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
561
+ layers_per_block: int = 2,
562
+ norm_num_groups: int = 32,
563
+ act_fn: str = "silu",
564
+ norm_type: str = "group", # group, spatial
565
+ ):
566
+ super().__init__()
567
+
568
+ self.conv_in = nn.Conv2d(
569
+ in_channels,
570
+ block_out_channels[-1],
571
+ kernel_size=3,
572
+ stride=1,
573
+ padding=1,
574
+ )
575
+
576
+ self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
577
+
578
+ self.mid_block = None
579
+ self.up_blocks = nn.ModuleList([])
580
+
581
+ temb_channels = in_channels if norm_type == "spatial" else None
582
+
583
+ # mid
584
+ self.mid_block = AllegroMidBlock3DConv(
585
+ in_channels=block_out_channels[-1],
586
+ resnet_eps=1e-6,
587
+ resnet_act_fn=act_fn,
588
+ output_scale_factor=1,
589
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
590
+ attention_head_dim=block_out_channels[-1],
591
+ resnet_groups=norm_num_groups,
592
+ temb_channels=temb_channels,
593
+ )
594
+
595
+ # up
596
+ reversed_block_out_channels = list(reversed(block_out_channels))
597
+ output_channel = reversed_block_out_channels[0]
598
+ for i, up_block_type in enumerate(up_block_types):
599
+ prev_output_channel = output_channel
600
+ output_channel = reversed_block_out_channels[i]
601
+
602
+ is_final_block = i == len(block_out_channels) - 1
603
+
604
+ if up_block_type == "AllegroUpBlock3D":
605
+ up_block = AllegroUpBlock3D(
606
+ num_layers=layers_per_block + 1,
607
+ in_channels=prev_output_channel,
608
+ out_channels=output_channel,
609
+ spatial_upsample=not is_final_block,
610
+ temporal_upsample=temporal_upsample_blocks[i],
611
+ resnet_eps=1e-6,
612
+ resnet_act_fn=act_fn,
613
+ resnet_groups=norm_num_groups,
614
+ temb_channels=temb_channels,
615
+ resnet_time_scale_shift=norm_type,
616
+ )
617
+ else:
618
+ raise ValueError("Invalid `up_block_type` encountered. Must be `AllegroUpBlock3D`")
619
+
620
+ self.up_blocks.append(up_block)
621
+ prev_output_channel = output_channel
622
+
623
+ # out
624
+ if norm_type == "spatial":
625
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
626
+ else:
627
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
628
+
629
+ self.conv_act = nn.SiLU()
630
+
631
+ self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
632
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
633
+
634
+ self.gradient_checkpointing = False
635
+
636
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
637
+ batch_size = sample.shape[0]
638
+
639
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
640
+ sample = self.conv_in(sample)
641
+
642
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
643
+ residual = sample
644
+ sample = self.temp_conv_in(sample)
645
+ sample = sample + residual
646
+
647
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
648
+
649
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
650
+
651
+ def create_custom_forward(module):
652
+ def custom_forward(*inputs):
653
+ return module(*inputs)
654
+
655
+ return custom_forward
656
+
657
+ # Mid block
658
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
659
+
660
+ # Up blocks
661
+ for up_block in self.up_blocks:
662
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
663
+
664
+ else:
665
+ # Mid block
666
+ sample = self.mid_block(sample)
667
+ sample = sample.to(upscale_dtype)
668
+
669
+ # Up blocks
670
+ for up_block in self.up_blocks:
671
+ sample = up_block(sample)
672
+
673
+ # Post process
674
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
675
+ sample = self.conv_norm_out(sample)
676
+ sample = self.conv_act(sample)
677
+
678
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
679
+ residual = sample
680
+ sample = self.temp_conv_out(sample)
681
+ sample = sample + residual
682
+
683
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
684
+ sample = self.conv_out(sample)
685
+
686
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
687
+ return sample
688
+
689
+
690
+ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
691
+ r"""
692
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
693
+ [Allegro](https://github.com/rhymes-ai/Allegro).
694
+
695
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
696
+ for all models (such as downloading or saving).
697
+
698
+ Parameters:
699
+ in_channels (int, defaults to `3`):
700
+ Number of channels in the input image.
701
+ out_channels (int, defaults to `3`):
702
+ Number of channels in the output.
703
+ down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
704
+ Tuple of strings denoting which types of down blocks to use.
705
+ up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
706
+ Tuple of strings denoting which types of up blocks to use.
707
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
708
+ Tuple of integers denoting number of output channels in each block.
709
+ temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
710
+ Tuple of booleans denoting which blocks to enable temporal downsampling in.
711
+ latent_channels (`int`, defaults to `4`):
712
+ Number of channels in latents.
713
+ layers_per_block (`int`, defaults to `2`):
714
+ Number of ResNet, attention, or temporal convolution layers per down/up block.
715
+ act_fn (`str`, defaults to `"silu"`):
716
+ The activation function to use.
717
+ norm_num_groups (`int`, defaults to `32`):
718
+ Number of groups to use in normalization layers.
719
+ temporal_compression_ratio (`int`, defaults to `4`):
720
+ Ratio by which the temporal dimension of samples is compressed.
721
+ sample_size (`int`, defaults to `320`):
722
+ Default latent size.
723
+ scaling_factor (`float`, defaults to `0.13`):
724
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
725
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
726
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
727
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
728
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
729
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
730
+ force_upcast (`bool`, default to `True`):
731
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
732
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
733
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
734
+ """
735
+
736
+ _supports_gradient_checkpointing = True
737
+
738
+ @register_to_config
739
+ def __init__(
740
+ self,
741
+ in_channels: int = 3,
742
+ out_channels: int = 3,
743
+ down_block_types: Tuple[str, ...] = (
744
+ "AllegroDownBlock3D",
745
+ "AllegroDownBlock3D",
746
+ "AllegroDownBlock3D",
747
+ "AllegroDownBlock3D",
748
+ ),
749
+ up_block_types: Tuple[str, ...] = (
750
+ "AllegroUpBlock3D",
751
+ "AllegroUpBlock3D",
752
+ "AllegroUpBlock3D",
753
+ "AllegroUpBlock3D",
754
+ ),
755
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
756
+ temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
757
+ temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
758
+ latent_channels: int = 4,
759
+ layers_per_block: int = 2,
760
+ act_fn: str = "silu",
761
+ norm_num_groups: int = 32,
762
+ temporal_compression_ratio: float = 4,
763
+ sample_size: int = 320,
764
+ scaling_factor: float = 0.13,
765
+ force_upcast: bool = True,
766
+ ) -> None:
767
+ super().__init__()
768
+
769
+ self.encoder = AllegroEncoder3D(
770
+ in_channels=in_channels,
771
+ out_channels=latent_channels,
772
+ down_block_types=down_block_types,
773
+ temporal_downsample_blocks=temporal_downsample_blocks,
774
+ block_out_channels=block_out_channels,
775
+ layers_per_block=layers_per_block,
776
+ act_fn=act_fn,
777
+ norm_num_groups=norm_num_groups,
778
+ double_z=True,
779
+ )
780
+ self.decoder = AllegroDecoder3D(
781
+ in_channels=latent_channels,
782
+ out_channels=out_channels,
783
+ up_block_types=up_block_types,
784
+ temporal_upsample_blocks=temporal_upsample_blocks,
785
+ block_out_channels=block_out_channels,
786
+ layers_per_block=layers_per_block,
787
+ norm_num_groups=norm_num_groups,
788
+ act_fn=act_fn,
789
+ )
790
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
791
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
792
+
793
+ # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need
794
+ # to use a specific parameter here or in other VAEs.
795
+
796
+ self.use_slicing = False
797
+ self.use_tiling = False
798
+
799
+ self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
800
+ self.tile_overlap_t = 8
801
+ self.tile_overlap_h = 120
802
+ self.tile_overlap_w = 80
803
+ sample_frames = 24
804
+
805
+ self.kernel = (sample_frames, sample_size, sample_size)
806
+ self.stride = (
807
+ sample_frames - self.tile_overlap_t,
808
+ sample_size - self.tile_overlap_h,
809
+ sample_size - self.tile_overlap_w,
810
+ )
811
+
812
+ def _set_gradient_checkpointing(self, module, value=False):
813
+ if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
814
+ module.gradient_checkpointing = value
815
+
816
+ def enable_tiling(self) -> None:
817
+ r"""
818
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
819
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
820
+ processing larger images.
821
+ """
822
+ self.use_tiling = True
823
+
824
+ def disable_tiling(self) -> None:
825
+ r"""
826
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
827
+ decoding in one step.
828
+ """
829
+ self.use_tiling = False
830
+
831
+ def enable_slicing(self) -> None:
832
+ r"""
833
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
834
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
835
+ """
836
+ self.use_slicing = True
837
+
838
+ def disable_slicing(self) -> None:
839
+ r"""
840
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
841
+ decoding in one step.
842
+ """
843
+ self.use_slicing = False
844
+
845
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
846
+ # TODO(aryan)
847
+ # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
848
+ if self.use_tiling:
849
+ return self.tiled_encode(x)
850
+
851
+ raise NotImplementedError("Encoding without tiling has not been implemented yet.")
852
+
853
+ @apply_forward_hook
854
+ def encode(
855
+ self, x: torch.Tensor, return_dict: bool = True
856
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
857
+ r"""
858
+ Encode a batch of videos into latents.
859
+
860
+ Args:
861
+ x (`torch.Tensor`):
862
+ Input batch of videos.
863
+ return_dict (`bool`, defaults to `True`):
864
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
865
+
866
+ Returns:
867
+ The latent representations of the encoded videos. If `return_dict` is True, a
868
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
869
+ """
870
+ if self.use_slicing and x.shape[0] > 1:
871
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
872
+ h = torch.cat(encoded_slices)
873
+ else:
874
+ h = self._encode(x)
875
+
876
+ posterior = DiagonalGaussianDistribution(h)
877
+
878
+ if not return_dict:
879
+ return (posterior,)
880
+ return AutoencoderKLOutput(latent_dist=posterior)
881
+
882
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
883
+ # TODO(aryan): refactor tiling implementation
884
+ # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
885
+ if self.use_tiling:
886
+ return self.tiled_decode(z)
887
+
888
+ raise NotImplementedError("Decoding without tiling has not been implemented yet.")
889
+
890
+ @apply_forward_hook
891
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
892
+ """
893
+ Decode a batch of videos.
894
+
895
+ Args:
896
+ z (`torch.Tensor`):
897
+ Input batch of latent vectors.
898
+ return_dict (`bool`, defaults to `True`):
899
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
900
+
901
+ Returns:
902
+ [`~models.vae.DecoderOutput`] or `tuple`:
903
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
904
+ returned.
905
+ """
906
+ if self.use_slicing and z.shape[0] > 1:
907
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
908
+ decoded = torch.cat(decoded_slices)
909
+ else:
910
+ decoded = self._decode(z)
911
+
912
+ if not return_dict:
913
+ return (decoded,)
914
+ return DecoderOutput(sample=decoded)
915
+
916
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
917
+ local_batch_size = 1
918
+ rs = self.spatial_compression_ratio
919
+ rt = self.config.temporal_compression_ratio
920
+
921
+ batch_size, num_channels, num_frames, height, width = x.shape
922
+
923
+ output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1
924
+ output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1
925
+ output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1
926
+
927
+ count = 0
928
+ output_latent = x.new_zeros(
929
+ (
930
+ output_num_frames * output_height * output_width,
931
+ 2 * self.config.latent_channels,
932
+ self.kernel[0] // rt,
933
+ self.kernel[1] // rs,
934
+ self.kernel[2] // rs,
935
+ )
936
+ )
937
+ vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]))
938
+
939
+ for i in range(output_num_frames):
940
+ for j in range(output_height):
941
+ for k in range(output_width):
942
+ n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
943
+ h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
944
+ w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
945
+
946
+ video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
947
+ vae_batch_input[count % local_batch_size] = video_cube
948
+
949
+ if (
950
+ count % local_batch_size == local_batch_size - 1
951
+ or count == output_num_frames * output_height * output_width - 1
952
+ ):
953
+ latent = self.encoder(vae_batch_input)
954
+
955
+ if (
956
+ count == output_num_frames * output_height * output_width - 1
957
+ and count % local_batch_size != local_batch_size - 1
958
+ ):
959
+ output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1]
960
+ else:
961
+ output_latent[count - local_batch_size + 1 : count + 1] = latent
962
+
963
+ vae_batch_input = x.new_zeros(
964
+ (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])
965
+ )
966
+
967
+ count += 1
968
+
969
+ latent = x.new_zeros(
970
+ (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)
971
+ )
972
+ output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
973
+ output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
974
+ output_overlap = (
975
+ output_kernel[0] - output_stride[0],
976
+ output_kernel[1] - output_stride[1],
977
+ output_kernel[2] - output_stride[2],
978
+ )
979
+
980
+ for i in range(output_num_frames):
981
+ n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0]
982
+ for j in range(output_height):
983
+ h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1]
984
+ for k in range(output_width):
985
+ w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2]
986
+ latent_mean = _prepare_for_blend(
987
+ (i, output_num_frames, output_overlap[0]),
988
+ (j, output_height, output_overlap[1]),
989
+ (k, output_width, output_overlap[2]),
990
+ output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0),
991
+ )
992
+ latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean
993
+
994
+ latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1)
995
+ latent = self.quant_conv(latent)
996
+ latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
997
+ return latent
998
+
999
+ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
1000
+ local_batch_size = 1
1001
+ rs = self.spatial_compression_ratio
1002
+ rt = self.config.temporal_compression_ratio
1003
+
1004
+ latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
1005
+ latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
1006
+
1007
+ batch_size, num_channels, num_frames, height, width = z.shape
1008
+
1009
+ ## post quant conv (a mapping)
1010
+ z = z.permute(0, 2, 1, 3, 4).flatten(0, 1)
1011
+ z = self.post_quant_conv(z)
1012
+ z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
1013
+
1014
+ output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1
1015
+ output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1
1016
+ output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1
1017
+
1018
+ count = 0
1019
+ decoded_videos = z.new_zeros(
1020
+ (
1021
+ output_num_frames * output_height * output_width,
1022
+ self.config.out_channels,
1023
+ self.kernel[0],
1024
+ self.kernel[1],
1025
+ self.kernel[2],
1026
+ )
1027
+ )
1028
+ vae_batch_input = z.new_zeros(
1029
+ (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
1030
+ )
1031
+
1032
+ for i in range(output_num_frames):
1033
+ for j in range(output_height):
1034
+ for k in range(output_width):
1035
+ n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0]
1036
+ h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1]
1037
+ w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2]
1038
+
1039
+ current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
1040
+ vae_batch_input[count % local_batch_size] = current_latent
1041
+
1042
+ if (
1043
+ count % local_batch_size == local_batch_size - 1
1044
+ or count == output_num_frames * output_height * output_width - 1
1045
+ ):
1046
+ current_video = self.decoder(vae_batch_input)
1047
+
1048
+ if (
1049
+ count == output_num_frames * output_height * output_width - 1
1050
+ and count % local_batch_size != local_batch_size - 1
1051
+ ):
1052
+ decoded_videos[count - count % local_batch_size :] = current_video[
1053
+ : count % local_batch_size + 1
1054
+ ]
1055
+ else:
1056
+ decoded_videos[count - local_batch_size + 1 : count + 1] = current_video
1057
+
1058
+ vae_batch_input = z.new_zeros(
1059
+ (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
1060
+ )
1061
+
1062
+ count += 1
1063
+
1064
+ video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs))
1065
+ video_overlap = (
1066
+ self.kernel[0] - self.stride[0],
1067
+ self.kernel[1] - self.stride[1],
1068
+ self.kernel[2] - self.stride[2],
1069
+ )
1070
+
1071
+ for i in range(output_num_frames):
1072
+ n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
1073
+ for j in range(output_height):
1074
+ h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
1075
+ for k in range(output_width):
1076
+ w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
1077
+ out_video_blend = _prepare_for_blend(
1078
+ (i, output_num_frames, video_overlap[0]),
1079
+ (j, output_height, video_overlap[1]),
1080
+ (k, output_width, video_overlap[2]),
1081
+ decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0),
1082
+ )
1083
+ video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
1084
+
1085
+ video = video.permute(0, 2, 1, 3, 4).contiguous()
1086
+ return video
1087
+
1088
+ def forward(
1089
+ self,
1090
+ sample: torch.Tensor,
1091
+ sample_posterior: bool = False,
1092
+ return_dict: bool = True,
1093
+ generator: Optional[torch.Generator] = None,
1094
+ ) -> Union[DecoderOutput, torch.Tensor]:
1095
+ r"""
1096
+ Args:
1097
+ sample (`torch.Tensor`): Input sample.
1098
+ sample_posterior (`bool`, *optional*, defaults to `False`):
1099
+ Whether to sample from the posterior.
1100
+ return_dict (`bool`, *optional*, defaults to `True`):
1101
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1102
+ generator (`torch.Generator`, *optional*):
1103
+ PyTorch random number generator.
1104
+ """
1105
+ x = sample
1106
+ posterior = self.encode(x).latent_dist
1107
+ if sample_posterior:
1108
+ z = posterior.sample(generator=generator)
1109
+ else:
1110
+ z = posterior.mode()
1111
+ dec = self.decode(z).sample
1112
+
1113
+ if not return_dict:
1114
+ return (dec,)
1115
+
1116
+ return DecoderOutput(sample=dec)
1117
+
1118
+
1119
+ def _prepare_for_blend(n_param, h_param, w_param, x):
1120
+ # TODO(aryan): refactor
1121
+ n, n_max, overlap_n = n_param
1122
+ h, h_max, overlap_h = h_param
1123
+ w, w_max, overlap_w = w_param
1124
+ if overlap_n > 0:
1125
+ if n > 0: # the head overlap part decays from 0 to 1
1126
+ x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * (
1127
+ torch.arange(0, overlap_n).float().to(x.device) / overlap_n
1128
+ ).reshape(overlap_n, 1, 1)
1129
+ if n < n_max - 1: # the tail overlap part decays from 1 to 0
1130
+ x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * (
1131
+ 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n
1132
+ ).reshape(overlap_n, 1, 1)
1133
+ if h > 0:
1134
+ x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * (
1135
+ torch.arange(0, overlap_h).float().to(x.device) / overlap_h
1136
+ ).reshape(overlap_h, 1)
1137
+ if h < h_max - 1:
1138
+ x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * (
1139
+ 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h
1140
+ ).reshape(overlap_h, 1)
1141
+ if w > 0:
1142
+ x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * (
1143
+ torch.arange(0, overlap_w).float().to(x.device) / overlap_w
1144
+ )
1145
+ if w < w_max - 1:
1146
+ x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * (
1147
+ 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w
1148
+ )
1149
+ return x
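
The file above only implements encoding and decoding through the tiled code paths (`_encode` and `_decode` raise `NotImplementedError` otherwise), so `enable_tiling()` must be called before use. The following minimal sketch, which is not part of the package, round-trips a dummy video through the new `AutoencoderKLAllegro`. The toy channel widths, group count, tile size, and tensor shape are illustrative assumptions chosen so the model runs quickly with random weights; real usage would load pretrained Allegro weights via `from_pretrained` instead.

# Hedged sketch: toy config values and dummy shapes are assumptions, not package defaults.
import torch

from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro(
    block_out_channels=(8, 16, 32, 32),  # toy widths; the real default is (128, 256, 512, 512)
    norm_num_groups=8,
    sample_size=64,                      # spatial tile size; the real default is 320
)
vae.enable_tiling()   # encode/decode are only implemented via tiled_encode/tiled_decode
vae.enable_slicing()  # optional: process the batch one sample at a time

# Dummy video batch: (batch, channels, frames, height, width)
video = torch.randn(1, 3, 24, 64, 64)

with torch.no_grad():
    posterior = vae.encode(video).latent_dist     # DiagonalGaussianDistribution
    latents = posterior.sample()                  # (1, 4, 6, 8, 8) with this toy config
    reconstruction = vae.decode(latents).sample   # back to (1, 3, 24, 64, 64)

print(latents.shape, reconstruction.shape)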