diffusers 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +1 -14
- diffusers/dependency_versions_table.py +5 -4
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +11 -6
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- diffusers/utils/versions.py +117 -0
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/METADATA +83 -64
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/RECORD +176 -157
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +1 -0
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
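For orientation, the most visible additions in the file list above are the new modular diffusers/loaders/ package (replacing the monolithic loaders.py), the Kandinsky 3 pipelines, and Stable Video Diffusion (pipeline_stable_video_diffusion.py plus unet_spatio_temporal_condition.py, whose full diff follows below). As a rough, non-authoritative sketch of the kind of usage the new video pipeline enables — the checkpoint id and keyword arguments are taken from the upstream SVD docs and may differ, and the input path is a placeholder:

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# Image-to-video pipeline introduced in 0.24.0 (checkpoint name assumed from upstream docs).
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()  # requires `accelerate`; alternatively pipe.to("cuda")

# Condition on a single input frame; decode latents in chunks to bound VAE memory use.
image = load_image("conditioning_frame.png")  # placeholder path
frames = pipe(image, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)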
diffusers/models/unet_spatio_temporal_condition.py
ADDED
@@ -0,0 +1,489 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None


class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional spatio-temporal UNet model that takes noisy video frames, a conditional state, and a timestep and
    returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        addition_time_embed_dim (`int`, defaults to 256):
            Dimension to encode the additional time ids.
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
            The number of attention heads.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 8,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
        ),
        up_block_types: Tuple[str] = (
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
        ),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        addition_time_embed_dim: int = 256,
        projection_class_embeddings_input_dim: int = 768,
        layers_per_block: Union[int, Tuple[int]] = 2,
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
        num_frames: int = 25,
    ):
        super().__init__()

        self.sample_size = sample_size

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1,
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockSpatioTemporal(
            block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=1e-5,
                resolution_idx=i,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
        self.conv_act = nn.SiLU()

        self.conv_out = nn.Conv2d(
            block_out_channels[0],
            out_channels,
            kernel_size=3,
            padding=1,
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
        r"""
        The [`UNetSpatioTemporalConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
            added_time_ids (`torch.FloatTensor`):
                The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
                embeddings and added to the time embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] instead of a plain
                tuple.
        Returns:
            [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size, num_frames = sample.shape[:2]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        time_embeds = time_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(time_embeds)
        emb = emb + aug_emb

        # Flatten the batch and frames dimensions
        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
        sample = sample.flatten(0, 1)
        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0)
        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

        # 2. pre-process
        sample = self.conv_in(sample)

        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    image_only_indicator=image_only_indicator,
                )

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
        )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    image_only_indicator=image_only_indicator,
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # 7. Reshape back to original shape
        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

        if not return_dict:
            return (sample,)

        return UNetSpatioTemporalConditionOutput(sample=sample)
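The forward method above fixes the tensor contract of the new model: `sample` is `(batch, num_frames, in_channels, height, width)`, `encoder_hidden_states` is `(batch, sequence_length, cross_attention_dim)`, and `added_time_ids` carries `projection_class_embeddings_input_dim / addition_time_embed_dim` scalar ids per sample (3 with the defaults shown). A minimal smoke-test sketch follows, assuming the class is re-exported from the top-level diffusers namespace as the updated __init__.py entries suggest; the configuration values are deliberately tiny and illustrative, not the SVD defaults:

import torch
from diffusers import UNetSpatioTemporalConditionModel  # or diffusers.models.unet_spatio_temporal_condition

# Tiny configuration for a CPU smoke test; the class defaults above target the SVD checkpoints.
unet = UNetSpatioTemporalConditionModel(
    block_out_channels=(32, 64, 64, 64),
    cross_attention_dim=32,
    num_attention_heads=(2, 4, 4, 8),
    addition_time_embed_dim=32,
    projection_class_embeddings_input_dim=96,  # 3 added time ids x addition_time_embed_dim
)

batch, frames = 1, 4
sample = torch.randn(batch, frames, 8, 32, 32)       # (batch, num_frames, in_channels, height, width)
encoder_hidden_states = torch.randn(batch, 1, 32)    # (batch, seq_len, cross_attention_dim)
added_time_ids = torch.tensor([[6.0, 127.0, 0.02]])  # three scalar conditioning ids per sample

with torch.no_grad():
    out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids)
print(out.sample.shape)  # torch.Size([1, 4, 4, 32, 32])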
diffusers/models/vae.py
CHANGED
@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
 from ..utils.torch_utils import randn_tensor
 from .activations import get_activation
 from .attention_processor import SpatialNorm
-from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import (
+    AutoencoderTinyBlock,
+    UNetMidBlock2D,
+    get_down_block,
+    get_up_block,
+)
 
 
 @dataclass
@@ -274,7 +279,9 @@ class Decoder(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
+        self,
+        sample: torch.FloatTensor,
+        latent_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
 
@@ -292,14 +299,20 @@ class Decoder(nn.Module):
             if is_torch_version(">=", "1.11.0"):
                 # middle
                 sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    latent_embeds,
+                    use_reentrant=False,
                 )
                 sample = sample.to(upscale_dtype)
 
                 # up
                 for up_block in self.up_blocks:
                     sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+                        create_custom_forward(up_block),
+                        sample,
+                        latent_embeds,
+                        use_reentrant=False,
                     )
             else:
                 # middle
@@ -540,7 +553,10 @@ class MaskConditionDecoder(nn.Module):
             if is_torch_version(">=", "1.11.0"):
                 # middle
                 sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    latent_embeds,
+                    use_reentrant=False,
                 )
                 sample = sample.to(upscale_dtype)
 
@@ -548,7 +564,10 @@ class MaskConditionDecoder(nn.Module):
                 if image is not None and mask is not None:
                     masked_image = (1 - mask) * image
                     im_x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
+                        create_custom_forward(self.condition_encoder),
+                        masked_image,
+                        mask,
+                        use_reentrant=False,
                     )
 
                 # up
@@ -558,7 +577,10 @@ class MaskConditionDecoder(nn.Module):
                         mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                         sample = sample * mask_ + sample_ * (1 - mask_)
                     sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+                        create_custom_forward(up_block),
+                        sample,
+                        latent_embeds,
+                        use_reentrant=False,
                     )
                 if image is not None and mask is not None:
                     sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,7 +595,9 @@ class MaskConditionDecoder(nn.Module):
                 if image is not None and mask is not None:
                     masked_image = (1 - mask) * image
                     im_x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.condition_encoder), masked_image, mask
+                        create_custom_forward(self.condition_encoder),
+                        masked_image,
+                        mask,
                     )
 
                 # up
@@ -754,7 +778,10 @@ class DiagonalGaussianDistribution(object):
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         # make sure sample is on the same device as the parameters and has same dtype
         sample = randn_tensor(
-            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
         )
         x = self.mean + self.std * sample
         return x
@@ -764,7 +791,10 @@ class DiagonalGaussianDistribution(object):
             return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3],
+                )
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
@@ -779,7 +809,10 @@ class DiagonalGaussianDistribution(object):
         if self.deterministic:
             return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
 
     def mode(self) -> torch.Tensor:
         return self.mean
@@ -820,7 +853,16 @@ class EncoderTiny(nn.Module):
             if i == 0:
                 layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
             else:
-                layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
+                layers.append(
+                    nn.Conv2d(
+                        num_channels,
+                        num_channels,
+                        kernel_size=3,
+                        padding=1,
+                        stride=2,
+                        bias=False,
+                    )
+                )
 
             for _ in range(num_block):
                 layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
@@ -899,7 +941,15 @@ class DecoderTiny(nn.Module):
                 layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
 
             conv_out_channel = num_channels if not is_final_block else out_channels
-            layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
+            layers.append(
+                nn.Conv2d(
+                    num_channels,
+                    conv_out_channel,
+                    kernel_size=3,
+                    padding=1,
+                    bias=is_final_block,
+                )
+            )
 
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False
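Note that every vae.py hunk above is a pure reflow of a long single-line call onto one argument per line; no behaviour changes. For reference, a minimal sketch exercising the DiagonalGaussianDistribution methods whose bodies were reflowed (shapes chosen arbitrarily for illustration):

import torch
from diffusers.models.vae import DiagonalGaussianDistribution

# `parameters` packs mean and logvar along the channel axis, as AutoencoderKL's encoder produces.
parameters = torch.randn(2, 8, 16, 16)  # split into mean (2, 4, 16, 16) and logvar (2, 4, 16, 16)
dist = DiagonalGaussianDistribution(parameters)

latent = dist.sample(generator=torch.Generator().manual_seed(0))
kl_to_standard_normal = dist.kl()  # 0.5 * sum(mean^2 + var - 1 - logvar) over dims [1, 2, 3]
nll = dist.nll(latent)             # Gaussian negative log-likelihood of `latent` under the distribution
print(latent.shape, kl_to_standard_normal.shape, nll.shape)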