diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -0,0 +1,1070 @@
+ # Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # We gratefully acknowledge the Wan Team for their outstanding contributions.
+ # QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
+ # For more information about the Wan VAE, please refer to:
+ # - GitHub: https://github.com/Wan-Video/Wan2.1
+ # - arXiv: https://arxiv.org/abs/2503.20314
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import FromOriginalModelMixin
+ from ...utils import logging
+ from ...utils.accelerate_utils import apply_forward_hook
+ from ..activations import get_activation
+ from ..modeling_outputs import AutoencoderKLOutput
+ from ..modeling_utils import ModelMixin
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ CACHE_T = 2
+
+
+ class QwenImageCausalConv3d(nn.Conv3d):
+     r"""
+     A custom 3D causal convolution layer with feature caching support.
+
+     This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
+     caching for efficient inference.
+
+     Args:
+         in_channels (int): Number of channels in the input image
+         out_channels (int): Number of channels produced by the convolution
+         kernel_size (int or tuple): Size of the convolving kernel
+         stride (int or tuple, optional): Stride of the convolution. Default: 1
+         padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: Union[int, Tuple[int, int, int]],
+         stride: Union[int, Tuple[int, int, int]] = 1,
+         padding: Union[int, Tuple[int, int, int]] = 0,
+     ) -> None:
+         super().__init__(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+         )
+
+         # Set up causal padding
+         self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+         self.padding = (0, 0, 0)
+
+     def forward(self, x, cache_x=None):
+         padding = list(self._padding)
+         if cache_x is not None and self._padding[4] > 0:
+             cache_x = cache_x.to(x.device)
+             x = torch.cat([cache_x, x], dim=2)
+             padding[4] -= cache_x.shape[2]
+         x = F.pad(x, padding)
+         return super().forward(x)
+
+
+ class QwenImageRMS_norm(nn.Module):
+     r"""
+     A custom RMS normalization layer.
+
+     Args:
+         dim (int): The number of dimensions to normalize over.
+         channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+             Default is True.
+         images (bool, optional): Whether the input represents image data. Default is True.
+         bias (bool, optional): Whether to include a learnable bias term. Default is False.
+     """
+
+     def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+         super().__init__()
+         broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+         shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+         self.channel_first = channel_first
+         self.scale = dim**0.5
+         self.gamma = nn.Parameter(torch.ones(shape))
+         self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+     def forward(self, x):
+         return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+ class QwenImageUpsample(nn.Upsample):
+     r"""
+     Perform upsampling while ensuring the output tensor has the same data type as the input.
+
+     Args:
+         x (torch.Tensor): Input tensor to be upsampled.
+
+     Returns:
+         torch.Tensor: Upsampled tensor with the same data type as the input.
+     """
+
+     def forward(self, x):
+         return super().forward(x.float()).type_as(x)
+
+
+ class QwenImageResample(nn.Module):
+     r"""
+     A custom resampling module for 2D and 3D data.
+
+     Args:
+         dim (int): The number of input/output channels.
+         mode (str): The resampling mode. Must be one of:
+             - 'none': No resampling (identity operation).
+             - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+             - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+             - 'downsample2d': 2D downsampling with zero-padding and convolution.
+             - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+     """
+
+     def __init__(self, dim: int, mode: str) -> None:
+         super().__init__()
+         self.dim = dim
+         self.mode = mode
+
+         # layers
+         if mode == "upsample2d":
+             self.resample = nn.Sequential(
+                 QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                 nn.Conv2d(dim, dim // 2, 3, padding=1),
+             )
+         elif mode == "upsample3d":
+             self.resample = nn.Sequential(
+                 QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                 nn.Conv2d(dim, dim // 2, 3, padding=1),
+             )
+             self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+         elif mode == "downsample2d":
+             self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+         elif mode == "downsample3d":
+             self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+             self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+         else:
+             self.resample = nn.Identity()
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         b, c, t, h, w = x.size()
+         if self.mode == "upsample3d":
+             if feat_cache is not None:
+                 idx = feat_idx[0]
+                 if feat_cache[idx] is None:
+                     feat_cache[idx] = "Rep"
+                     feat_idx[0] += 1
+                 else:
+                     cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                     if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                         # cache last frame of last two chunk
+                         cache_x = torch.cat(
+                             [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                         )
+                     if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                         cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                     if feat_cache[idx] == "Rep":
+                         x = self.time_conv(x)
+                     else:
+                         x = self.time_conv(x, feat_cache[idx])
+                     feat_cache[idx] = cache_x
+                     feat_idx[0] += 1
+
+                     x = x.reshape(b, 2, c, t, h, w)
+                     x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                     x = x.reshape(b, c, t * 2, h, w)
+         t = x.shape[2]
+         x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+         x = self.resample(x)
+         x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+         if self.mode == "downsample3d":
+             if feat_cache is not None:
+                 idx = feat_idx[0]
+                 if feat_cache[idx] is None:
+                     feat_cache[idx] = x.clone()
+                     feat_idx[0] += 1
+                 else:
+                     cache_x = x[:, :, -1:, :, :].clone()
+                     x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                     feat_cache[idx] = cache_x
+                     feat_idx[0] += 1
+         return x
+
+
+ class QwenImageResidualBlock(nn.Module):
+     r"""
+     A custom residual block module.
+
+     Args:
+         in_dim (int): Number of input channels.
+         out_dim (int): Number of output channels.
+         dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
+         non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+     """
+
+     def __init__(
+         self,
+         in_dim: int,
+         out_dim: int,
+         dropout: float = 0.0,
+         non_linearity: str = "silu",
+     ) -> None:
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+         self.nonlinearity = get_activation(non_linearity)
+
+         # layers
+         self.norm1 = QwenImageRMS_norm(in_dim, images=False)
+         self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
+         self.norm2 = QwenImageRMS_norm(out_dim, images=False)
+         self.dropout = nn.Dropout(dropout)
+         self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
+         self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         # Apply shortcut connection
+         h = self.conv_shortcut(x)
+
+         # First normalization and activation
+         x = self.norm1(x)
+         x = self.nonlinearity(x)
+
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone()
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+             x = self.conv1(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv1(x)
+
+         # Second normalization and activation
+         x = self.norm2(x)
+         x = self.nonlinearity(x)
+
+         # Dropout
+         x = self.dropout(x)
+
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone()
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+             x = self.conv2(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv2(x)
+
+         # Add residual connection
+         return x + h
+
+
+ class QwenImageAttentionBlock(nn.Module):
+     r"""
+     Causal self-attention with a single head.
+
+     Args:
+         dim (int): The number of channels in the input tensor.
+     """
+
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+         # layers
+         self.norm = QwenImageRMS_norm(dim)
+         self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+         self.proj = nn.Conv2d(dim, dim, 1)
+
+     def forward(self, x):
+         identity = x
+         batch_size, channels, time, height, width = x.size()
+
+         x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+         x = self.norm(x)
+
+         # compute query, key, value
+         qkv = self.to_qkv(x)
+         qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+         qkv = qkv.permute(0, 1, 3, 2).contiguous()
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # apply attention
+         x = F.scaled_dot_product_attention(q, k, v)
+
+         x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+
+         # output projection
+         x = self.proj(x)
+
+         # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
+         x = x.view(batch_size, time, channels, height, width)
+         x = x.permute(0, 2, 1, 3, 4)
+
+         return x + identity
+
+
+ class QwenImageMidBlock(nn.Module):
+     """
+     Middle block for QwenImageVAE encoder and decoder.
+
+     Args:
+         dim (int): Number of input/output channels.
+         dropout (float): Dropout rate.
+         non_linearity (str): Type of non-linearity to use.
+     """
+
+     def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+         super().__init__()
+         self.dim = dim
+
+         # Create the components
+         resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
+         attentions = []
+         for _ in range(num_layers):
+             attentions.append(QwenImageAttentionBlock(dim))
+             resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         # First residual block
+         x = self.resnets[0](x, feat_cache, feat_idx)
+
+         # Process through attention and residual blocks
+         for attn, resnet in zip(self.attentions, self.resnets[1:]):
+             if attn is not None:
+                 x = attn(x)
+
+             x = resnet(x, feat_cache, feat_idx)
+
+         return x
+
+
+ class QwenImageEncoder3d(nn.Module):
+     r"""
+     A 3D encoder module.
+
+     Args:
+         dim (int): The base number of channels in the first layer.
+         z_dim (int): The dimensionality of the latent space.
+         dim_mult (list of int): Multipliers for the number of channels in each block.
+         num_res_blocks (int): Number of residual blocks in each block.
+         attn_scales (list of float): Scales at which to apply attention mechanisms.
+         temperal_downsample (list of bool): Whether to downsample temporally in each block.
+         dropout (float): Dropout rate for the dropout layers.
+         non_linearity (str): Type of non-linearity to use.
+     """
+
+     def __init__(
+         self,
+         dim=128,
+         z_dim=4,
+         dim_mult=[1, 2, 4, 4],
+         num_res_blocks=2,
+         attn_scales=[],
+         temperal_downsample=[True, True, False],
+         dropout=0.0,
+         non_linearity: str = "silu",
+     ):
+         super().__init__()
+         self.dim = dim
+         self.z_dim = z_dim
+         self.dim_mult = dim_mult
+         self.num_res_blocks = num_res_blocks
+         self.attn_scales = attn_scales
+         self.temperal_downsample = temperal_downsample
+         self.nonlinearity = get_activation(non_linearity)
+
+         # dimensions
+         dims = [dim * u for u in [1] + dim_mult]
+         scale = 1.0
+
+         # init block
+         self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
+
+         # downsample blocks
+         self.down_blocks = nn.ModuleList([])
+         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+             # residual (+attention) blocks
+             for _ in range(num_res_blocks):
+                 self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
+                 if scale in attn_scales:
+                     self.down_blocks.append(QwenImageAttentionBlock(out_dim))
+                 in_dim = out_dim
+
+             # downsample block
+             if i != len(dim_mult) - 1:
+                 mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                 self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
+                 scale /= 2.0
+
+         # middle blocks
+         self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+
+         # output blocks
+         self.norm_out = QwenImageRMS_norm(out_dim, images=False)
+         self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone()
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 # cache last frame of last two chunk
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+             x = self.conv_in(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv_in(x)
+
+         ## downsamples
+         for layer in self.down_blocks:
+             if feat_cache is not None:
+                 x = layer(x, feat_cache, feat_idx)
+             else:
+                 x = layer(x)
+
+         ## middle
+         x = self.mid_block(x, feat_cache, feat_idx)
+
+         ## head
+         x = self.norm_out(x)
+         x = self.nonlinearity(x)
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone()
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 # cache last frame of last two chunk
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+             x = self.conv_out(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv_out(x)
+         return x
+
+
+ class QwenImageUpBlock(nn.Module):
+     """
+     A block that handles upsampling for the QwenImageVAE decoder.
+
+     Args:
+         in_dim (int): Input dimension
+         out_dim (int): Output dimension
+         num_res_blocks (int): Number of residual blocks
+         dropout (float): Dropout rate
+         upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+         non_linearity (str): Type of non-linearity to use
+     """
+
+     def __init__(
+         self,
+         in_dim: int,
+         out_dim: int,
+         num_res_blocks: int,
+         dropout: float = 0.0,
+         upsample_mode: Optional[str] = None,
+         non_linearity: str = "silu",
+     ):
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         # Create layers list
+         resnets = []
+         # Add residual blocks and attention if needed
+         current_dim = in_dim
+         for _ in range(num_res_blocks + 1):
+             resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
+             current_dim = out_dim
+
+         self.resnets = nn.ModuleList(resnets)
+
+         # Add upsampling layer if needed
+         self.upsamplers = None
+         if upsample_mode is not None:
+             self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         """
+         Forward pass through the upsampling block.
+
+         Args:
+             x (torch.Tensor): Input tensor
+             feat_cache (list, optional): Feature cache for causal convolutions
+             feat_idx (list, optional): Feature index for cache management
+
+         Returns:
+             torch.Tensor: Output tensor
+         """
+         for resnet in self.resnets:
+             if feat_cache is not None:
+                 x = resnet(x, feat_cache, feat_idx)
+             else:
+                 x = resnet(x)
+
+         if self.upsamplers is not None:
+             if feat_cache is not None:
+                 x = self.upsamplers[0](x, feat_cache, feat_idx)
+             else:
+                 x = self.upsamplers[0](x)
+         return x
+
+
+ class QwenImageDecoder3d(nn.Module):
+     r"""
+     A 3D decoder module.
+
+     Args:
+         dim (int): The base number of channels in the first layer.
+         z_dim (int): The dimensionality of the latent space.
+         dim_mult (list of int): Multipliers for the number of channels in each block.
+         num_res_blocks (int): Number of residual blocks in each block.
+         attn_scales (list of float): Scales at which to apply attention mechanisms.
+         temperal_upsample (list of bool): Whether to upsample temporally in each block.
+         dropout (float): Dropout rate for the dropout layers.
+         non_linearity (str): Type of non-linearity to use.
+     """
+
+     def __init__(
+         self,
+         dim=128,
+         z_dim=4,
+         dim_mult=[1, 2, 4, 4],
+         num_res_blocks=2,
+         attn_scales=[],
+         temperal_upsample=[False, True, True],
+         dropout=0.0,
+         non_linearity: str = "silu",
+     ):
+         super().__init__()
+         self.dim = dim
+         self.z_dim = z_dim
+         self.dim_mult = dim_mult
+         self.num_res_blocks = num_res_blocks
+         self.attn_scales = attn_scales
+         self.temperal_upsample = temperal_upsample
+
+         self.nonlinearity = get_activation(non_linearity)
+
+         # dimensions
+         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+         scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+         # init block
+         self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+         # middle blocks
+         self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+         # upsample blocks
+         self.up_blocks = nn.ModuleList([])
+         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+             # residual (+attention) blocks
+             if i > 0:
+                 in_dim = in_dim // 2
+
+             # Determine if we need upsampling
+             upsample_mode = None
+             if i != len(dim_mult) - 1:
+                 upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+
+             # Create and add the upsampling block
+             up_block = QwenImageUpBlock(
+                 in_dim=in_dim,
+                 out_dim=out_dim,
+                 num_res_blocks=num_res_blocks,
+                 dropout=dropout,
+                 upsample_mode=upsample_mode,
+                 non_linearity=non_linearity,
+             )
+             self.up_blocks.append(up_block)
+
+             # Update scale for next iteration
+             if upsample_mode is not None:
+                 scale *= 2.0
+
+         # output blocks
+         self.norm_out = QwenImageRMS_norm(out_dim, images=False)
+         self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         ## conv1
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone()
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 # cache last frame of last two chunk
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+             x = self.conv_in(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv_in(x)
+
+         ## middle
+         x = self.mid_block(x, feat_cache, feat_idx)
+
+         ## upsamples
+         for up_block in self.up_blocks:
+             x = up_block(x, feat_cache, feat_idx)
+
+         ## head
+         x = self.norm_out(x)
+         x = self.nonlinearity(x)
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone()
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 # cache last frame of last two chunk
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+             x = self.conv_out(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv_out(x)
+         return x
+
+
+ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+     r"""
+     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+     for all models (such as downloading or saving).
+     """
+
+     _supports_gradient_checkpointing = False
+
+     # fmt: off
+     @register_to_config
+     def __init__(
+         self,
+         base_dim: int = 96,
+         z_dim: int = 16,
+         dim_mult: Tuple[int] = [1, 2, 4, 4],
+         num_res_blocks: int = 2,
+         attn_scales: List[float] = [],
+         temperal_downsample: List[bool] = [False, True, True],
+         dropout: float = 0.0,
+         latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
+         latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
+     ) -> None:
+         # fmt: on
+         super().__init__()
+
+         self.z_dim = z_dim
+         self.temperal_downsample = temperal_downsample
+         self.temperal_upsample = temperal_downsample[::-1]
+
+         self.encoder = QwenImageEncoder3d(
+             base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+         )
+         self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
+         self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
+
+         self.decoder = QwenImageDecoder3d(
+             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+         )
+
+         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+         # to perform decoding of a single video latent at a time.
+         self.use_slicing = False
+
+         # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+         # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+         # intermediate tiles together, the memory requirement can be lowered.
+         self.use_tiling = False
+
+         # The minimal tile height and width for spatial tiling to be used
+         self.tile_sample_min_height = 256
+         self.tile_sample_min_width = 256
+
+         # The minimal distance between two spatial tiles
+         self.tile_sample_stride_height = 192
+         self.tile_sample_stride_width = 192
+
+         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+         self._cached_conv_counts = {
+             "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
+             if self.decoder is not None
+             else 0,
+             "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
+             if self.encoder is not None
+             else 0,
+         }
+
+     def enable_tiling(
+         self,
+         tile_sample_min_height: Optional[int] = None,
+         tile_sample_min_width: Optional[int] = None,
+         tile_sample_stride_height: Optional[float] = None,
+         tile_sample_stride_width: Optional[float] = None,
+     ) -> None:
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+         processing larger images.
+
+         Args:
+             tile_sample_min_height (`int`, *optional*):
+                 The minimum height required for a sample to be separated into tiles across the height dimension.
+             tile_sample_min_width (`int`, *optional*):
+                 The minimum width required for a sample to be separated into tiles across the width dimension.
+             tile_sample_stride_height (`int`, *optional*):
+                 The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                 no tiling artifacts produced across the height dimension.
+             tile_sample_stride_width (`int`, *optional*):
+                 The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                 artifacts produced across the width dimension.
+         """
+         self.use_tiling = True
+         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+     def disable_tiling(self) -> None:
+         r"""
+         Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+         decoding in one step.
+         """
+         self.use_tiling = False
+
+     def enable_slicing(self) -> None:
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.use_slicing = True
+
+     def disable_slicing(self) -> None:
+         r"""
+         Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+         decoding in one step.
+         """
+         self.use_slicing = False
+
+     def clear_cache(self):
+         def _count_conv3d(model):
+             count = 0
+             for m in model.modules():
+                 if isinstance(m, QwenImageCausalConv3d):
+                     count += 1
+             return count
+
+         self._conv_num = _count_conv3d(self.decoder)
+         self._conv_idx = [0]
+         self._feat_map = [None] * self._conv_num
+         # cache encode
+         self._enc_conv_num = _count_conv3d(self.encoder)
+         self._enc_conv_idx = [0]
+         self._enc_feat_map = [None] * self._enc_conv_num
+
+     def _encode(self, x: torch.Tensor):
+         _, _, num_frame, height, width = x.shape
+
+         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+             return self.tiled_encode(x)
+
+         self.clear_cache()
+         iter_ = 1 + (num_frame - 1) // 4
+         for i in range(iter_):
+             self._enc_conv_idx = [0]
+             if i == 0:
+                 out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+             else:
+                 out_ = self.encoder(
+                     x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+                     feat_cache=self._enc_feat_map,
+                     feat_idx=self._enc_conv_idx,
+                 )
+                 out = torch.cat([out, out_], 2)
+
+         enc = self.quant_conv(out)
+         self.clear_cache()
+         return enc
+
+     @apply_forward_hook
+     def encode(
+         self, x: torch.Tensor, return_dict: bool = True
+     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+         r"""
+         Encode a batch of images into latents.
+
+         Args:
+             x (`torch.Tensor`): Input batch of images.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+         Returns:
+             The latent representations of the encoded videos. If `return_dict` is True, a
+             [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+         """
+         if self.use_slicing and x.shape[0] > 1:
+             encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+             h = torch.cat(encoded_slices)
+         else:
+             h = self._encode(x)
+         posterior = DiagonalGaussianDistribution(h)
+
+         if not return_dict:
+             return (posterior,)
+         return AutoencoderKLOutput(latent_dist=posterior)
+
+     def _decode(self, z: torch.Tensor, return_dict: bool = True):
+         _, _, num_frame, height, width = z.shape
+         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+             return self.tiled_decode(z, return_dict=return_dict)
+
+         self.clear_cache()
+         x = self.post_quant_conv(z)
+         for i in range(num_frame):
+             self._conv_idx = [0]
+             if i == 0:
+                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+             else:
+                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                 out = torch.cat([out, out_], 2)
+
+         out = torch.clamp(out, min=-1.0, max=1.0)
+         self.clear_cache()
+         if not return_dict:
+             return (out,)
+
+         return DecoderOutput(sample=out)
+
+     @apply_forward_hook
+     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+         r"""
+         Decode a batch of images.
+
+         Args:
+             z (`torch.Tensor`): Input batch of latent vectors.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+         Returns:
+             [`~models.vae.DecoderOutput`] or `tuple`:
+                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                 returned.
+         """
+         if self.use_slicing and z.shape[0] > 1:
+             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+             decoded = torch.cat(decoded_slices)
+         else:
+             decoded = self._decode(z).sample
+
+         if not return_dict:
+             return (decoded,)
+         return DecoderOutput(sample=decoded)
+
+     def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+         blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+         for y in range(blend_extent):
+             b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                 y / blend_extent
+             )
+         return b
+
+     def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+         blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+         for x in range(blend_extent):
+             b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                 x / blend_extent
+             )
+         return b
+
+     def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+         r"""Encode a batch of images using a tiled encoder.
+
+         Args:
+             x (`torch.Tensor`): Input batch of videos.
+
+         Returns:
+             `torch.Tensor`:
+                 The latent representation of the encoded videos.
+         """
+         _, _, num_frames, height, width = x.shape
+         latent_height = height // self.spatial_compression_ratio
+         latent_width = width // self.spatial_compression_ratio
+
+         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+         blend_height = tile_latent_min_height - tile_latent_stride_height
+         blend_width = tile_latent_min_width - tile_latent_stride_width
+
+         # Split x into overlapping tiles and encode them separately.
+         # The tiles have an overlap to avoid seams between tiles.
+         rows = []
+         for i in range(0, height, self.tile_sample_stride_height):
+             row = []
+             for j in range(0, width, self.tile_sample_stride_width):
+                 self.clear_cache()
+                 time = []
+                 frame_range = 1 + (num_frames - 1) // 4
+                 for k in range(frame_range):
+                     self._enc_conv_idx = [0]
+                     if k == 0:
+                         tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                     else:
+                         tile = x[
+                             :,
+                             :,
+                             1 + 4 * (k - 1) : 1 + 4 * k,
+                             i : i + self.tile_sample_min_height,
+                             j : j + self.tile_sample_min_width,
+                         ]
+                     tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                     tile = self.quant_conv(tile)
+                     time.append(tile)
+                 row.append(torch.cat(time, dim=2))
+             rows.append(row)
+         self.clear_cache()
+
+         result_rows = []
+         for i, row in enumerate(rows):
+             result_row = []
+             for j, tile in enumerate(row):
+                 # blend the above tile and the left tile
+                 # to the current tile and add the current tile to the result row
+                 if i > 0:
+                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                 if j > 0:
+                     tile = self.blend_h(row[j - 1], tile, blend_width)
+                 result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+             result_rows.append(torch.cat(result_row, dim=-1))
+
+         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+         return enc
+
+     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+         r"""
+         Decode a batch of images using a tiled decoder.
+
+         Args:
+             z (`torch.Tensor`): Input batch of latent vectors.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+         Returns:
+             [`~models.vae.DecoderOutput`] or `tuple`:
+                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                 returned.
+         """
+         _, _, num_frames, height, width = z.shape
+         sample_height = height * self.spatial_compression_ratio
+         sample_width = width * self.spatial_compression_ratio
+
+         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+         blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+         blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+         # Split z into overlapping tiles and decode them separately.
+         # The tiles have an overlap to avoid seams between tiles.
+         rows = []
+         for i in range(0, height, tile_latent_stride_height):
+             row = []
+             for j in range(0, width, tile_latent_stride_width):
+                 self.clear_cache()
+                 time = []
+                 for k in range(num_frames):
+                     self._conv_idx = [0]
+                     tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                     tile = self.post_quant_conv(tile)
+                     decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                     time.append(decoded)
+                 row.append(torch.cat(time, dim=2))
+             rows.append(row)
+         self.clear_cache()
+
+         result_rows = []
+         for i, row in enumerate(rows):
+             result_row = []
+             for j, tile in enumerate(row):
+                 # blend the above tile and the left tile
+                 # to the current tile and add the current tile to the result row
+                 if i > 0:
+                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                 if j > 0:
+                     tile = self.blend_h(row[j - 1], tile, blend_width)
+                 result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+             result_rows.append(torch.cat(result_row, dim=-1))
+
+         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+         if not return_dict:
+             return (dec,)
+         return DecoderOutput(sample=dec)
+
+     def forward(
+         self,
+         sample: torch.Tensor,
+         sample_posterior: bool = False,
+         return_dict: bool = True,
+         generator: Optional[torch.Generator] = None,
+     ) -> Union[DecoderOutput, torch.Tensor]:
+         """
+         Args:
+             sample (`torch.Tensor`): Input sample.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+         """
+         x = sample
+         posterior = self.encode(x).latent_dist
+         if sample_posterior:
+             z = posterior.sample(generator=generator)
+         else:
+             z = posterior.mode()
+         dec = self.decode(z, return_dict=return_dict)
+         return dec
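
For orientation, here is a minimal usage sketch (not part of the diff) exercising the AutoencoderKLQwenImage model added in this release. It assumes diffusers 0.35.0 is installed and that the class is exported at the top level like the other autoencoders; the random initialization and tensor shapes are illustrative only.

```python
# Minimal sketch: round-trip a tiny video tensor through the new Qwen-Image VAE.
import torch
from diffusers import AutoencoderKLQwenImage

# Randomly initialized model with the default config (z_dim=16, 8x spatial compression).
# In practice you would load pretrained weights with AutoencoderKLQwenImage.from_pretrained(...).
vae = AutoencoderKLQwenImage()
vae.eval()
vae.enable_slicing()  # encode/decode one batch element at a time to save memory
vae.enable_tiling()   # spatially tile inputs larger than tile_sample_min_height/width (256)

# (batch, channels, frames, height, width); height and width should be divisible by
# the spatial compression ratio (8 with the default config).
video = torch.randn(1, 3, 1, 64, 64)

with torch.no_grad():
    posterior = vae.encode(video).latent_dist
    latents = posterior.mode()          # -> (1, 16, 1, 8, 8) with the default config
    recon = vae.decode(latents).sample  # -> (1, 3, 1, 64, 64), clamped to [-1, 1]

print(latents.shape, recon.shape)
```

The `enable_slicing`/`enable_tiling` toggles correspond to the memory-saving paths (`use_slicing`, `use_tiling`, `tiled_encode`, `tiled_decode`) defined in the file above; tiling only activates once an input exceeds the configured minimum tile size.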