diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
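
The headline addition in this release is the Stable Diffusion XL pipeline (items 72-75 above), which also brings the new invisible-watermark dependency reflected in the dummy-objects file (item 108). A minimal, hedged usage sketch follows; the checkpoint id, fp16/CUDA settings, and prompt are illustrative assumptions rather than part of this diff, and the SD-XL 0.9 weights are gated on the Hub.

    # Hedged sketch (not part of this diff): loading the StableDiffusionXLPipeline added in 0.18.
    # The model id below is an assumption; substitute any SD-XL checkpoint you have access to.
    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    image = pipe(prompt="an astronaut riding a horse").images[0]
    image.save("astronaut.png")
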
diffusers/models/transformer_temporal.py
@@ -26,9 +26,11 @@ from .modeling_utils import ModelMixin
 @dataclass
 class TransformerTemporalModelOutput(BaseOutput):
     """
+    The output of [`TransformerTemporalModel`].
+
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
-            Hidden states conditioned on `encoder_hidden_states` input.
+        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input.
     """
 
     sample: torch.FloatTensor
@@ -36,24 +38,23 @@ class TransformerTemporalModelOutput(BaseOutput):
 
 class TransformerTemporalModel(ModelMixin, ConfigMixin):
     """
-    Transformer model for video-like data.
+    A Transformer model for video-like data.
 
     Parameters:
         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         in_channels (`int`, *optional*):
-            Pass if the input is continuous. The number of channels in the input and output.
+            The number of channels in the input and output (specify if the input is **continuous**).
         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
-        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
-            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
-            `ImagePositionalEmbeddings`.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
         attention_bias (`bool`, *optional*):
-            Configure if the TransformerBlocks' attention should contain a bias parameter.
+            Configure if the `TransformerBlock` attention should contain a bias parameter.
         double_self_attention (`bool`, *optional*):
-            Configure if each TransformerBlock should contain two self-attention layers
+            Configure if each `TransformerBlock` should contain two self-attention layers.
     """
 
     @register_to_config
@@ -114,25 +115,27 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ):
         """
+        The [`TransformerTemporal`] forward method.
+
         Args:
-            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
-                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
-                Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
-                conditioning.
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
 
         Returns:
-            [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
-            [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is the sample tensor.
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                returned, otherwise a `tuple` where the first element is the sample tensor.
         """
         # 1. Input
         batch_frames, channel, height, width = hidden_states.shape
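
For orientation, a minimal sketch of the forward call documented in this hunk. All configuration values are illustrative (chosen only so the tensor shapes line up) and are not taken from this diff.

    # Hedged sketch: calling TransformerTemporalModel.forward on random frames.
    import torch
    from diffusers.models.transformer_temporal import TransformerTemporalModel

    # Illustrative config: inner dim = num_attention_heads * attention_head_dim = 32;
    # in_channels must be divisible by the default norm_num_groups (32).
    model = TransformerTemporalModel(num_attention_heads=8, attention_head_dim=4, in_channels=32)

    batch, num_frames, height, width = 2, 4, 8, 8
    # Frames are flattened into the batch dimension: (batch * num_frames, channels, height, width).
    hidden_states = torch.randn(batch * num_frames, 32, height, width)

    out = model(hidden_states, num_frames=num_frames, return_dict=True)
    print(out.sample.shape)  # same shape as the input hidden_states
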
diffusers/models/unet_1d.py
@@ -28,9 +28,11 @@ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up
 @dataclass
 class UNet1DOutput(BaseOutput):
     """
+    The output of [`UNet1DModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
-            Hidden states output. Output of last layer of model.
+            The hidden states output from the last layer of the model.
     """
 
     sample: torch.FloatTensor
@@ -38,10 +40,10 @@ class UNet1DOutput(BaseOutput):
 
 class UNet1DModel(ModelMixin, ConfigMixin):
     r"""
-    UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
@@ -49,24 +51,24 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
         extra_in_channels (`int`, *optional*, defaults to 0):
             Number of additional channels to be added to the input of the first down block. Useful for cases where the
-            input data has more channels than what the model is initially designed for.
+            input data has more channels than what the model was initially designed for.
         time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
-        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
-        flip_sin_to_cos (`bool`, *optional*, defaults to :
-            obj:`False`): Whether to flip sin to cos for fourier time embedding.
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types.
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-            obj:`(32, 32, 64)`): Tuple of block output channels.
-        mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
-        out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
-        act_fn (`str`, *optional*, defaults to None): optional activation function in UNet blocks.
-        norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
-        layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
-        downsample_each_block (`int`, *optional*, defaults to False:
-            experimental feature for using a UNet without upsampling.
+        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip sin to cos for Fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
+            Tuple of block output channels.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
+        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
+        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
+        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
+        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
+        downsample_each_block (`int`, *optional*, defaults to `False`):
+            Experimental feature for using a UNet without upsampling.
     """
 
     @register_to_config
@@ -197,15 +199,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ) -> Union[UNet1DOutput, Tuple]:
         r"""
+        The [`UNet1DModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
 
         Returns:
-            [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
-            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the sample tensor.
         """
 
         # 1. time
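
A hedged sketch of the forward call documented in this hunk. The checkpoint id and `subfolder` are assumptions (a dance-diffusion style audio UNet on the Hub); any `UNet1DModel` weights would do.

    # Hedged sketch: UNet1DModel.forward with and without return_dict.
    import torch
    from diffusers import UNet1DModel

    unet = UNet1DModel.from_pretrained("harmonai/maestro-150k", subfolder="unet")  # assumed repo layout

    sample = torch.randn(1, unet.config.in_channels, unet.config.sample_size)
    with torch.no_grad():
        out = unet(sample, timestep=10)                           # UNet1DOutput
        out_tuple = unet(sample, timestep=10, return_dict=False)  # plain tuple
    print(out.sample.shape, out_tuple[0].shape)
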
diffusers/models/unet_2d.py
@@ -27,9 +27,11 @@ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 @dataclass
 class UNet2DOutput(BaseOutput):
     """
+    The output of [`UNet2DModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Hidden states output. Output of last layer of model.
+            The hidden states output from the last layer of the model.
     """
 
     sample: torch.FloatTensor
@@ -37,46 +39,49 @@ class UNet2DOutput(BaseOutput):
 
 class UNet2DModel(ModelMixin, ConfigMixin):
     r"""
-    UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
-        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
         time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
-        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
-        flip_sin_to_cos (`bool`, *optional*, defaults to :
-            obj:`True`): Whether to flip sin to cos for fourier time embedding.
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
-            types.
+        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip sin to cos for Fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+            Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-            The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-            obj:`(224, 448, 672, 896)`): Tuple of block output channels.
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+            Tuple of block output channels.
         layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
         mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
         downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+        downsample_type (`str`, *optional*, defaults to `conv`):
+            The downsample type for downsampling layers. Choose between "conv" and "resnet"
+        upsample_type (`str`, *optional*, defaults to `conv`):
+            The upsample type for upsampling layers. Choose between "conv" and "resnet"
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
-        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
-        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
-            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
-        class_embed_type (`str`, *optional*, defaults to None):
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
-        num_class_embeds (`int`, *optional*, defaults to None):
-            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
-            class conditioning with `class_embed_type` equal to `None`.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
+            conditioning with `class_embed_type` equal to `None`.
     """
 
     @register_to_config
@@ -95,6 +100,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         layers_per_block: int = 2,
         mid_block_scale_factor: float = 1,
         downsample_padding: int = 1,
+        downsample_type: str = "conv",
+        upsample_type: str = "conv",
         act_fn: str = "silu",
         attention_head_dim: Optional[int] = 8,
         norm_num_groups: int = 32,
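
The two new constructor arguments shown above select how the attention down/up blocks resample. A minimal sketch with a toy configuration follows; every value here is illustrative and not taken from this diff.

    # Hedged sketch: a small UNet2DModel using the new ResNet-based resampling.
    import torch
    from diffusers import UNet2DModel

    model = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        downsample_type="resnet",  # ResNet downsampling in Attn* blocks instead of a strided conv
        upsample_type="resnet",    # ResNet upsampling instead of conv + interpolation
    )

    out = model(torch.randn(1, 3, 32, 32), timestep=1).sample
    print(out.shape)  # (1, 3, 32, 32)
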
@@ -164,9 +171,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                attn_num_head_channels=attention_head_dim,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 downsample_padding=downsample_padding,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                downsample_type=downsample_type,
             )
             self.down_blocks.append(down_block)
 
@@ -178,7 +186,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift=resnet_time_scale_shift,
-            attn_num_head_channels=attention_head_dim,
+            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
             resnet_groups=norm_num_groups,
             add_attention=add_attention,
         )
@@ -204,8 +212,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                attn_num_head_channels=attention_head_dim,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                upsample_type=upsample_type,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -224,17 +233,21 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ) -> Union[UNet2DOutput, Tuple]:
         r"""
+        The [`UNet2DModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
 
         Returns:
-            [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
-            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the sample tensor.
         """
         # 0. center input if necessary
         if self.config.center_input_sample:
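
For completeness, a small sketch of the forward call documented in this last hunk, using a public unconditional DDPM checkpoint; the model id is an assumption and any `UNet2DModel` weights work.

    # Hedged sketch: UNet2DModel.forward and its two return styles.
    import torch
    from diffusers import UNet2DModel

    model = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")

    sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    with torch.no_grad():
        out = model(sample, timestep=50)                               # UNet2DOutput (return_dict=True by default)
        noise_pred = model(sample, timestep=50, return_dict=False)[0]  # plain tuple
    print(out.sample.shape, noise_pred.shape)
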