diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
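Among the files listed above, 0.18 introduces several new pipelines (Stable Diffusion XL, Kandinsky 2.2, Shap-E, consistency models, LDM3D, paradigms) along with new schedulers (consistency models, parallel DDIM/DDPM). For orientation, a minimal usage sketch for the new Stable Diffusion XL pipeline follows; it is not taken from the diff, it assumes the standard from_pretrained/__call__ API, and the checkpoint id is an assumption. The new dummy_torch_and_transformers_and_invisible_watermark_objects.py suggests this pipeline also expects the invisible-watermark package to be installed.

import torch
from diffusers import StableDiffusionXLPipeline

# Sketch only: the checkpoint id below is assumed, not taken from this diff.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Standard text-to-image call; the returned output object exposes the generated PIL images.
image = pipe(prompt="a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")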
diffusers/models/transformer_temporal.py
CHANGED
@@ -26,9 +26,11 @@ from .modeling_utils import ModelMixin
 @dataclass
 class TransformerTemporalModelOutput(BaseOutput):
     """
+    The output of [`TransformerTemporalModel`].
+
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
-
+        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input.
     """

     sample: torch.FloatTensor
@@ -36,24 +38,23 @@ class TransformerTemporalModelOutput(BaseOutput):

 class TransformerTemporalModel(ModelMixin, ConfigMixin):
     """
-    Transformer model for video-like data.
+    A Transformer model for video-like data.

     Parameters:
         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         in_channels (`int`, *optional*):
-
+            The number of channels in the input and output (specify if the input is **continuous**).
         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
-        sample_size (`int`, *optional*):
-
-
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
         attention_bias (`bool`, *optional*):
-            Configure if the
+            Configure if the `TransformerBlock` attention should contain a bias parameter.
         double_self_attention (`bool`, *optional*):
-            Configure if each TransformerBlock should contain two self-attention layers
+            Configure if each `TransformerBlock` should contain two self-attention layers.
     """

     @register_to_config
@@ -114,25 +115,27 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ):
         """
+        The [`TransformerTemporal`] forward method.
+
         Args:
-            hidden_states (
-
-                hidden_states
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
-                Optional timestep to be applied as an embedding in AdaLayerNorm
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
-                Optional class labels to be applied as an embedding in
-
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.

         Returns:
-            [`~models.
-
-
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                returned, otherwise a `tuple` where the first element is the sample tensor.
         """
         # 1. Input
         batch_frames, channel, height, width = hidden_states.shape
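For orientation, a minimal sketch (not part of the diff) of the continuous-input path documented above; the constructor values are illustrative rather than the library defaults.

import torch
from diffusers.models.transformer_temporal import TransformerTemporalModel

# Illustrative config: inner dimension = 8 heads * 32 = 256; in_channels is chosen
# to be divisible by the default norm_num_groups of 32.
model = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=32,
    in_channels=32,
    num_layers=1,
)

batch_size, num_frames = 2, 4
hidden_states = torch.randn(batch_size * num_frames, 32, 8, 8)

# Without encoder_hidden_states the blocks fall back to self-attention, as noted above.
out = model(hidden_states, num_frames=num_frames).sample
print(out.shape)  # expected: torch.Size([8, 32, 8, 8])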
diffusers/models/unet_1d.py
CHANGED
@@ -28,9 +28,11 @@ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up
 @dataclass
 class UNet1DOutput(BaseOutput):
     """
+    The output of [`UNet1DModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
-
+            The hidden states output from the last layer of the model.
     """

     sample: torch.FloatTensor
@@ -38,10 +40,10 @@ class UNet1DOutput(BaseOutput):

 class UNet1DModel(ModelMixin, ConfigMixin):
     r"""
-
+    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

-    This model inherits from [`ModelMixin`]. Check the superclass documentation for
-
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).

     Parameters:
         sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
@@ -49,24 +51,24 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
         extra_in_channels (`int`, *optional*, defaults to 0):
             Number of additional channels to be added to the input of the first down block. Useful for cases where the
-            input data has more channels than what the model
+            input data has more channels than what the model was initially designed for.
         time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
-        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for
-        flip_sin_to_cos (`bool`, *optional*, defaults to :
-
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-
-        mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"):
-        out_block_type (`str`, *optional*, defaults to `None`):
-        act_fn (`str`, *optional*, defaults to None):
-        norm_num_groups (`int`, *optional*, defaults to 8):
-        layers_per_block (`int`, *optional*, defaults to 1):
-        downsample_each_block (`int`, *optional*, defaults to False:
-
+        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip sin to cos for Fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
+            Tuple of block output channels.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
+        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
+        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
+        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
+        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
+        downsample_each_block (`int`, *optional*, defaults to `False`):
+            Experimental feature for using a UNet without upsampling.
     """

     @register_to_config
@@ -197,15 +199,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ) -> Union[UNet1DOutput, Tuple]:
         r"""
+        The [`UNet1DModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`):
-
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

         Returns:
-            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
-
+            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the sample tensor.
         """

         # 1. time
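For orientation, a minimal sketch (not part of the diff) of the forward contract documented above, using the dance-diffusion-style defaults; all sizes are illustrative, and extra_in_channels=16 is set on the assumption that the first down block concatenates the 16-channel Fourier timestep embedding to the input.

import torch
from diffusers import UNet1DModel

model = UNet1DModel(
    in_channels=2,
    out_channels=2,
    extra_in_channels=16,             # assumed: room for the Fourier time-embedding channels
    block_out_channels=(32, 32, 64),
)

sample = torch.randn(1, 2, 256)       # (batch_size, num_channels, sample_size)
timestep = torch.tensor([10])

out = model(sample, timestep).sample  # UNet1DOutput.sample, same shape as the input
print(out.shape)  # expected: torch.Size([1, 2, 256])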
diffusers/models/unet_2d.py
CHANGED
@@ -27,9 +27,11 @@ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 @dataclass
 class UNet2DOutput(BaseOutput):
     """
+    The output of [`UNet2DModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-
+            The hidden states output from the last layer of the model.
     """

     sample: torch.FloatTensor
@@ -37,46 +39,49 @@ class UNet2DOutput(BaseOutput):

 class UNet2DModel(ModelMixin, ConfigMixin):
     r"""
-
+    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

-    This model inherits from [`ModelMixin`]. Check the superclass documentation for
-
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).

     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
             Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
             1)`.
-        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input
+        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
         time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
-        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for
-        flip_sin_to_cos (`bool`, *optional*, defaults to :
-
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-
-            types.
+        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip sin to cos for Fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+            Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+            Tuple of block output channels.
         layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
         mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
         downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+        downsample_type (`str`, *optional*, defaults to `conv`):
+            The downsample type for downsampling layers. Choose between "conv" and "resnet"
+        upsample_type (`str`, *optional*, defaults to `conv`):
+            The upsample type for upsampling layers. Choose between "conv" and "resnet"
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
-        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for
-        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for
+        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
-            for
-        class_embed_type (`str`, *optional*, defaults to None):
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, or `"identity"`.
-        num_class_embeds (`int`, *optional*, defaults to None):
-            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim
-
+        num_class_embeds (`int`, *optional*, defaults to `None`):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
+            conditioning with `class_embed_type` equal to `None`.
     """

     @register_to_config
@@ -95,6 +100,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         layers_per_block: int = 2,
         mid_block_scale_factor: float = 1,
         downsample_padding: int = 1,
+        downsample_type: str = "conv",
+        upsample_type: str = "conv",
         act_fn: str = "silu",
         attention_head_dim: Optional[int] = 8,
         norm_num_groups: int = 32,
@@ -164,9 +171,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 downsample_padding=downsample_padding,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                downsample_type=downsample_type,
             )
             self.down_blocks.append(down_block)

@@ -178,7 +186,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift=resnet_time_scale_shift,
-
+            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
             resnet_groups=norm_num_groups,
             add_attention=add_attention,
         )
@@ -204,8 +212,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                upsample_type=upsample_type,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -224,17 +233,21 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ) -> Union[UNet2DOutput, Tuple]:
         r"""
+        The [`UNet2DModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`):
-
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

         Returns:
-            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
-
+            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the sample tensor.
         """
         # 0. center input if necessary
         if self.config.center_input_sample:
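For reference, a minimal sketch (not part of the diff) that builds a small UNet2DModel with the parameters documented above, including the new downsample_type/upsample_type options; the sizes are illustrative, not the library defaults.

import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    downsample_type="resnet",  # new in 0.18: "conv" (default) or "resnet"
    upsample_type="resnet",    # new in 0.18: "conv" (default) or "resnet"
)

noisy_sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])

out = model(noisy_sample, timestep).sample  # UNet2DOutput.sample, same shape as the input
print(out.shape)  # expected: torch.Size([1, 3, 32, 32])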