diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/models/prior_transformer.py
@@ -1,5 +1,5 @@
  from dataclasses import dataclass
- from typing import Optional, Union
+ from typing import Dict, Optional, Union

  import torch
  import torch.nn.functional as F
@@ -8,6 +8,7 @@ from torch import nn
  from ..configuration_utils import ConfigMixin, register_to_config
  from ..utils import BaseOutput
  from .attention import BasicTransformerBlock
+ from .attention_processor import AttentionProcessor, AttnProcessor
  from .embeddings import TimestepEmbedding, Timesteps
  from .modeling_utils import ModelMixin

@@ -15,6 +16,8 @@ from .modeling_utils import ModelMixin
  @dataclass
  class PriorTransformerOutput(BaseOutput):
      """
+     The output of [`PriorTransformer`].
+
      Args:
          predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
              The predicted CLIP image embedding conditioned on the CLIP text embedding input.
@@ -25,27 +28,39 @@ class PriorTransformerOutput(BaseOutput):

  class PriorTransformer(ModelMixin, ConfigMixin):
      """
-     The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
-     transformer predicts the image embeddings through a denoising diffusion process.
-
-     This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-     implements for all the models (such as downloading or saving, etc.)
-
-     For more details, see the original paper: https://arxiv.org/abs/2204.06125
+     A Prior Transformer model.

      Parameters:
          num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
          attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
          num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
-         embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
-             image embeddings and text embeddings are both the same dimension.
-         num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
-             length of the prompt after it has been tokenized.
+         embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
+         num_embeddings (`int`, *optional*, defaults to 77):
+             The number of embeddings of the model input `hidden_states`
          additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
-             projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
+             projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
              additional_embeddings`.
          dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-
+         time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
+             The activation function to use to create timestep embeddings.
+         norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
+             passing to Transformer blocks. Set it to `None` if normalization is not needed.
+         embedding_proj_norm_type (`str`, *optional*, defaults to None):
+             The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
+             needed.
+         encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
+             The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
+             `encoder_hidden_states` is `None`.
+         added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
+             Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
+             product between the text embedding and image embedding as proposed in the unclip paper
+             https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
+         time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
+             If None, will be set to `num_attention_heads * attention_head_dim`
+         embedding_proj_dim (`int`, *optional*, default to None):
+             The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
+         clip_embed_dim (`int`, *optional*, default to None):
+             The dimension of the output. If None, will be set to `embedding_dim`.
      """

      @register_to_config
@@ -58,6 +73,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
          num_embeddings=77,
          additional_embeddings=4,
          dropout: float = 0.0,
+         time_embed_act_fn: str = "silu",
+         norm_in_type: Optional[str] = None,  # layer
+         embedding_proj_norm_type: Optional[str] = None,  # layer
+         encoder_hid_proj_type: Optional[str] = "linear",  # linear
+         added_emb_type: Optional[str] = "prd",  # prd
+         time_embed_dim: Optional[int] = None,
+         embedding_proj_dim: Optional[int] = None,
+         clip_embed_dim: Optional[int] = None,
      ):
          super().__init__()
          self.num_attention_heads = num_attention_heads
@@ -65,17 +88,41 @@ class PriorTransformer(ModelMixin, ConfigMixin):
          inner_dim = num_attention_heads * attention_head_dim
          self.additional_embeddings = additional_embeddings

+         time_embed_dim = time_embed_dim or inner_dim
+         embedding_proj_dim = embedding_proj_dim or embedding_dim
+         clip_embed_dim = clip_embed_dim or embedding_dim
+
          self.time_proj = Timesteps(inner_dim, True, 0)
-         self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
+         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

          self.proj_in = nn.Linear(embedding_dim, inner_dim)

-         self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
-         self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
+         if embedding_proj_norm_type is None:
+             self.embedding_proj_norm = None
+         elif embedding_proj_norm_type == "layer":
+             self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
+         else:
+             raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
+
+         self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
+
+         if encoder_hid_proj_type is None:
+             self.encoder_hidden_states_proj = None
+         elif encoder_hid_proj_type == "linear":
+             self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
+         else:
+             raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

          self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

-         self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
+         if added_emb_type == "prd":
+             self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
+         elif added_emb_type is None:
+             self.prd_embedding = None
+         else:
+             raise ValueError(
+                 f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
+             )

          self.transformer_blocks = nn.ModuleList(
              [
@@ -91,8 +138,16 @@ class PriorTransformer(ModelMixin, ConfigMixin):
              ]
          )

+         if norm_in_type == "layer":
+             self.norm_in = nn.LayerNorm(inner_dim)
+         elif norm_in_type is None:
+             self.norm_in = None
+         else:
+             raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
+
          self.norm_out = nn.LayerNorm(inner_dim)
-         self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
+
+         self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

          causal_attention_mask = torch.full(
              [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
@@ -101,23 +156,92 @@ class PriorTransformer(ModelMixin, ConfigMixin):
          causal_attention_mask = causal_attention_mask[None, ...]
          self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

-         self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
-         self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
+         self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
+         self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
+
+     @property
+     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model with
+             indexed by its weight name.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "set_processor"):
+                 processors[f"{name}.processor"] = module.processor
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+     def set_default_attn_processor(self):
+         """
+         Disables custom attention processors and sets the default attention implementation.
+         """
+         self.set_attn_processor(AttnProcessor())

      def forward(
          self,
          hidden_states,
          timestep: Union[torch.Tensor, float, int],
          proj_embedding: torch.FloatTensor,
-         encoder_hidden_states: torch.FloatTensor,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
          attention_mask: Optional[torch.BoolTensor] = None,
          return_dict: bool = True,
      ):
          """
+         The [`PriorTransformer`] forward method.
+
          Args:
              hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
-                 x_t, the currently predicted image embeddings.
-             timestep (`torch.long`):
+                 The currently predicted image embeddings.
+             timestep (`torch.LongTensor`):
                  Current denoising step.
              proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                  Projected embedding vector the denoising process is conditioned on.
@@ -126,13 +250,13 @@ class PriorTransformer(ModelMixin, ConfigMixin):
              attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                  Text mask for the text embeddings.
              return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
+                 Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
                  tuple.

          Returns:
              [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
-                 [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
-                 returning a tuple, the first element is the sample tensor.
+                 If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
+                 tuple is returned where the first element is the sample tensor.
          """
          batch_size = hidden_states.shape[0]

@@ -152,23 +276,61 @@ class PriorTransformer(ModelMixin, ConfigMixin):
          timesteps_projected = timesteps_projected.to(dtype=self.dtype)
          time_embeddings = self.time_embedding(timesteps_projected)

+         if self.embedding_proj_norm is not None:
+             proj_embedding = self.embedding_proj_norm(proj_embedding)
+
          proj_embeddings = self.embedding_proj(proj_embedding)
-         encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
+         if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
+             encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
+         elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
+             raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
+
          hidden_states = self.proj_in(hidden_states)
-         prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
+
          positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

+         additional_embeds = []
+         additional_embeddings_len = 0
+
+         if encoder_hidden_states is not None:
+             additional_embeds.append(encoder_hidden_states)
+             additional_embeddings_len += encoder_hidden_states.shape[1]
+
+         if len(proj_embeddings.shape) == 2:
+             proj_embeddings = proj_embeddings[:, None, :]
+
+         if len(hidden_states.shape) == 2:
+             hidden_states = hidden_states[:, None, :]
+
+         additional_embeds = additional_embeds + [
+             proj_embeddings,
+             time_embeddings[:, None, :],
+             hidden_states,
+         ]
+
+         if self.prd_embedding is not None:
+             prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
+             additional_embeds.append(prd_embedding)
+
          hidden_states = torch.cat(
-             [
-                 encoder_hidden_states,
-                 proj_embeddings[:, None, :],
-                 time_embeddings[:, None, :],
-                 hidden_states[:, None, :],
-                 prd_embedding,
-             ],
+             additional_embeds,
              dim=1,
          )

+         # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
+         additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
+         if positional_embeddings.shape[1] < hidden_states.shape[1]:
+             positional_embeddings = F.pad(
+                 positional_embeddings,
+                 (
+                     0,
+                     0,
+                     additional_embeddings_len,
+                     self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
+                 ),
+                 value=0.0,
+             )
+
          hidden_states = hidden_states + positional_embeddings

          if attention_mask is not None:
@@ -177,11 +339,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
              attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
              attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

+         if self.norm_in is not None:
+             hidden_states = self.norm_in(hidden_states)
+
          for block in self.transformer_blocks:
              hidden_states = block(hidden_states, attention_mask=attention_mask)

          hidden_states = self.norm_out(hidden_states)
-         hidden_states = hidden_states[:, -1]
+
+         if self.prd_embedding is not None:
+             hidden_states = hidden_states[:, -1]
+         else:
+             hidden_states = hidden_states[:, additional_embeddings_len:]
+
          predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

          if not return_dict:
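The hunks above also give `PriorTransformer` the pluggable attention-processor surface that `UNet2DConditionModel` already exposes (`attn_processors`, `set_attn_processor`, `set_default_attn_processor`). A minimal sketch of exercising that surface against 0.18.2; the tiny config values are illustrative only, not the documented defaults:

```python
from diffusers.models import PriorTransformer
from diffusers.models.attention_processor import AttnProcessor

# Deliberately tiny, illustrative config so the model is cheap to construct; real checkpoints
# use the documented defaults (32 heads, 64 dims per head, 20 layers, embedding_dim=768).
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=2,
    embedding_dim=8,
    num_embeddings=4,
)

# New in 0.18: enumerate every attention processor, keyed by module path.
print(list(prior.attn_processors))

# New in 0.18: set a processor for all Attention layers, or restore the default implementation.
prior.set_attn_processor(AttnProcessor())
prior.set_default_attn_processor()
```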
diffusers/models/resnet.py
@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
          assert self.channels == self.out_channels
          self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

-     def forward(self, x):
-         assert x.shape[1] == self.channels
-         return self.conv(x)
+     def forward(self, inputs):
+         assert inputs.shape[1] == self.channels
+         return self.conv(inputs)


  class Upsample2D(nn.Module):
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
          self.pad = kernel_1d.shape[1] // 2 - 1
          self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

-     def forward(self, x):
-         x = F.pad(x, (self.pad,) * 4, self.pad_mode)
-         weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-         indices = torch.arange(x.shape[1], device=x.device)
-         kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+     def forward(self, inputs):
+         inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+         weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+         indices = torch.arange(inputs.shape[1], device=inputs.device)
+         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
          weight[indices, indices] = kernel
-         return F.conv2d(x, weight, stride=2)
+         return F.conv2d(inputs, weight, stride=2)


  class KUpsample2D(nn.Module):
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
          self.pad = kernel_1d.shape[1] // 2 - 1
          self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

-     def forward(self, x):
-         x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
-         weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
-         indices = torch.arange(x.shape[1], device=x.device)
-         kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
+     def forward(self, inputs):
+         inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+         weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+         indices = torch.arange(inputs.shape[1], device=inputs.device)
+         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
          weight[indices, indices] = kernel
-         return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
+         return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


  class ResnetBlock2D(nn.Module):
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
          self.group_norm = nn.GroupNorm(n_groups, out_channels)
          self.mish = nn.Mish()

-     def forward(self, x):
-         x = self.conv1d(x)
-         x = rearrange_dims(x)
-         x = self.group_norm(x)
-         x = rearrange_dims(x)
-         x = self.mish(x)
-         return x
+     def forward(self, inputs):
+         intermediate_repr = self.conv1d(inputs)
+         intermediate_repr = rearrange_dims(intermediate_repr)
+         intermediate_repr = self.group_norm(intermediate_repr)
+         intermediate_repr = rearrange_dims(intermediate_repr)
+         output = self.mish(intermediate_repr)
+         return output


  # unet_rl.py
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
              nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
          )

-     def forward(self, x, t):
+     def forward(self, inputs, t):
          """
          Args:
-             x : [ batch_size x inp_channels x horizon ]
+             inputs : [ batch_size x inp_channels x horizon ]
              t : [ batch_size x embed_dim ]

          returns:
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
          """
          t = self.time_emb_act(t)
          t = self.time_emb(t)
-         out = self.conv_in(x) + rearrange_dims(t)
+         out = self.conv_in(inputs) + rearrange_dims(t)
          out = self.conv_out(out)
-         return out + self.residual_conv(x)
+         return out + self.residual_conv(inputs)


  def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
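The resnet.py hunks are a pure rename of the forward arguments (`x` becomes `inputs`, with named intermediates in `Conv1dBlock`); the computation is unchanged. A small sketch of what the rename means for callers, assuming the existing `Downsample1D` constructor signature (not shown in this diff):

```python
import torch
from diffusers.models.resnet import Downsample1D

down = Downsample1D(channels=8, use_conv=False)  # constructor assumed unchanged between 0.17.1 and 0.18.2
sample = torch.randn(1, 8, 16)

out = down(sample)            # positional call: identical behaviour in both versions
# down(x=sample)              # keyword call: only matched the 0.17.1 signature
# down(inputs=sample)         # keyword call: only matches the 0.18.2 signature
print(out.shape)              # torch.Size([1, 8, 8]) -- AvgPool1d with stride 2
```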
diffusers/models/transformer_2d.py
@@ -29,10 +29,12 @@ from .modeling_utils import ModelMixin
  @dataclass
  class Transformer2DModelOutput(BaseOutput):
      """
+     The output of [`Transformer2DModel`].
+
      Args:
          sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
-             Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
-             for the unnoised latent pixels.
+             The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+             distributions for the unnoised latent pixels.
      """

      sample: torch.FloatTensor
@@ -40,40 +42,30 @@ class Transformer2DModelOutput(BaseOutput):

  class Transformer2DModel(ModelMixin, ConfigMixin):
      """
-     Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
-     embeddings) inputs.
-
-     When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
-     transformer action. Finally, reshape to image.
-
-     When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
-     embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
-     classes of unnoised image.
-
-     Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
-     image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+     A 2D Transformer model for image-like data.

      Parameters:
          num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
          attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
          in_channels (`int`, *optional*):
-             Pass if the input is continuous. The number of channels in the input and output.
+             The number of channels in the input and output (specify if the input is **continuous**).
          num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
          dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-         cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
-         sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
-             Note that this is fixed at training time as it is used for learning a number of position embeddings. See
-             `ImagePositionalEmbeddings`.
+         cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+         sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+             This is fixed during training since it is used to learn a number of position embeddings.
          num_vector_embeds (`int`, *optional*):
-             Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+             The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
              Includes the class for the masked latent pixel.
-         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-         num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
-             The number of diffusion steps used during training. Note that this is fixed at training time as it is used
-             to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
-             up to but not more than steps than `num_embeds_ada_norm`.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+         num_embeds_ada_norm ( `int`, *optional*):
+             The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+             `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+             added to the hidden states.
+
+             During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
          attention_bias (`bool`, *optional*):
-             Configure if the TransformerBlocks' attention should contain a bias parameter.
+             Configure if the `TransformerBlocks` attention should contain a bias parameter.
      """

      @register_to_config
@@ -223,31 +215,34 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
          return_dict: bool = True,
      ):
          """
+         The [`Transformer2DModel`] forward method.
+
          Args:
-             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                 hidden_states
+             hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                 Input `hidden_states`.
              encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                  Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                  self-attention.
              timestep ( `torch.LongTensor`, *optional*):
-                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+                 Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
              class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
-                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
-                 conditioning.
-             encoder_attention_mask ( `torch.Tensor`, *optional* ).
-                 Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
-                     Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
-                     = keep, -10000 = discard.
-                 If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
+                 Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                 `AdaLayerZeroNorm`.
+             encoder_attention_mask ( `torch.Tensor`, *optional*):
+                 Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                     * Mask `(batch, sequence_length)` True = keep, False = discard.
+                     * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                 If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                  above. This bias will be added to the cross-attention scores.
              return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                 Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                 tuple.

          Returns:
-             [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
-                 [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
-                 returning a tuple, the first element is the sample tensor.
+             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+             `tuple` where the first element is the sample tensor.
          """
          # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
          # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
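The rewritten `encoder_attention_mask` documentation above lists the two accepted formats. A minimal sketch (not part of the diff) of the mask-to-bias conversion those docs imply, i.e. what the `ndim == 2` branch turns a boolean keep-mask into:

```python
import torch

keep_mask = torch.tensor([[True, True, False]])       # (batch, sequence_length), True = keep
bias = (1 - keep_mask.to(torch.float32)) * -10000.0   # 0 where kept, -10000 where discarded
bias = bias.unsqueeze(1)                               # (batch, 1, sequence_length)

# Adding `bias` to the cross-attention scores leaves kept positions untouched and pushes
# discarded positions to effectively zero probability after the softmax.
```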