diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
@@ -149,6 +149,9 @@ class Attention(nn.Module):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
         )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+        )
 
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
@@ -192,6 +195,17 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -203,6 +217,16 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = AttnProcessor()
 
@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class AttnAddedKVProcessor:
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=False,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
 ]
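
The two new processors follow the same pattern as the LoRA processors: they are plain `nn.Module`s attached via `Attention.set_processor`, and only their `*_custom_diffusion` projections carry new weights. A minimal sketch, assuming diffusers >= 0.16; the layer sizes are illustrative and not taken from this diff:

```python
# Minimal sketch of attaching the new Custom Diffusion processor to a single Attention block.
import torch
from diffusers.models.attention_processor import Attention, CustomDiffusionAttnProcessor

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(
    CustomDiffusionAttnProcessor(train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=768)
)

hidden_states = torch.randn(2, 64, 320)          # (batch, tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder output
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])

# The new projections register as ordinary parameters, so a training loop can freeze
# everything else and optimize only these:
print([n for n, _ in attn.named_parameters() if "custom_diffusion" in n])
# ['processor.to_k_custom_diffusion.weight', 'processor.to_v_custom_diffusion.weight']
```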
@@ -457,6 +457,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -557,8 +558,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
 
         if self.config.global_pool_conditions:
             down_block_res_samples = [
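
For reference, `guess_mode` replaces the single `conditioning_scale` with a geometric ramp from 0.1 on the shallowest residual up to 1.0 on the mid block. A quick sketch of the schedule for the standard Stable Diffusion UNet, which yields 12 down-block residuals plus the mid-block sample:

```python
# Sketch of the guess-mode scale schedule computed above (12 down-block residuals + mid block).
import torch

scales = torch.logspace(-1, 0, 13)  # 0.1 ... 1.0
print([round(s, 3) for s in scales.tolist()])
# [0.1, 0.121, 0.147, 0.178, 0.215, 0.261, 0.316, 0.383, 0.464, 0.562, 0.681, 0.825, 1.0]
```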
@@ -377,3 +377,69 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         conditioning = timesteps_emb + class_labels  # (N, D)
 
         return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(encoder_dim)
+        self.pool = AttentionPooling(num_heads, encoder_dim)
+        self.proj = nn.Linear(encoder_dim, time_embed_dim)
+        self.norm2 = nn.LayerNorm(time_embed_dim)
+
+    def forward(self, hidden_states):
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.pool(hidden_states)
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        return hidden_states
+
+
+class AttentionPooling(nn.Module):
+    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+    def __init__(self, num_heads, embed_dim, dtype=None):
+        super().__init__()
+        self.dtype = dtype
+        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.num_heads = num_heads
+        self.dim_per_head = embed_dim // self.num_heads
+
+    def forward(self, x):
+        bs, length, width = x.size()
+
+        def shape(x):
+            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+            x = x.transpose(1, 2)
+            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+            x = x.transpose(1, 2)
+            return x
+
+        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)
+
+        # (bs*n_heads, class_token_length, dim_per_head)
+        q = shape(self.q_proj(class_token))
+        # (bs*n_heads, length+class_token_length, dim_per_head)
+        k = shape(self.k_proj(x))
+        v = shape(self.v_proj(x))
+
+        # (bs*n_heads, class_token_length, length+class_token_length):
+        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+        # (bs*n_heads, dim_per_head, class_token_length)
+        a = torch.einsum("bts,bcs->bct", weight, v)
+
+        # (bs, length+1, width)
+        a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+        return a[:, 0, :]  # cls_token
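
`TextTimeEmbedding` pools a sequence of encoder states into a single vector in the time-embedding dimension. A minimal sketch, assuming diffusers >= 0.16; the sizes are illustrative (DeepFloyd IF feeds T5 hidden states of width 4096 here, but any `encoder_dim` works):

```python
# Minimal sketch of the new TextTimeEmbedding layer; dimensions are illustrative.
import torch
from diffusers.models.embeddings import TextTimeEmbedding

emb = TextTimeEmbedding(encoder_dim=4096, time_embed_dim=1024, num_heads=64)
encoder_hidden_states = torch.randn(2, 77, 4096)  # (batch, tokens, encoder_dim)
aug_emb = emb(encoder_hidden_states)
print(aug_emb.shape)  # torch.Size([2, 1024]) -- summed onto the UNet time embedding
```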
@@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
                     .replace("_1", ".1")
                     .replace("_2", ".2")
                     .replace("_3", ".3")
+                    .replace("_4", ".4")
+                    .replace("_5", ".5")
+                    .replace("_6", ".6")
+                    .replace("_7", ".7")
+                    .replace("_8", ".8")
+                    .replace("_9", ".9")
                 )
 
         flax_key = ".".join(flax_key_tuple_array)
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import inspect
+import itertools
 import os
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -60,7 +61,8 @@ if is_safetensors_available():
 
 def get_parameter_device(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).device
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).device
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
@@ -75,7 +77,8 @@ def get_parameter_device(parameter: torch.nn.Module):
 
 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).dtype
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).dtype
    except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
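
Chaining in `buffers()` matters for modules that register buffers but no parameters; previously those fell through to the `StopIteration` fallback. A small sketch of the case being fixed (the `BufferOnly` module is hypothetical):

```python
# Hypothetical module with buffers but no parameters: parameters() alone yields nothing,
# so device/dtype can only be read once buffers are chained in.
import itertools
import torch
import torch.nn as nn

class BufferOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scales", torch.linspace(0, 1, 10, dtype=torch.float16))

module = BufferOnly()
first = next(itertools.chain(module.parameters(), module.buffers()))
print(first.dtype, first.device)  # torch.float16 cpu
```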
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
-            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
         time_embedding_act_fn (`str`, *optional*, default to `None`):
             Optional activation function to use on the time embeddings only one time before they as passed to the rest
             of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -155,12 +160,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -170,6 +177,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()
 
@@ -214,7 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # time
         if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
                 raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
             self.time_proj = GaussianFourierProjection(
@@ -222,7 +230,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             )
             timestep_input_dim = time_embed_dim
         elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
 
             self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
             timestep_input_dim = block_out_channels[0]
@@ -248,7 +256,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -273,6 +281,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         else:
             self.class_embedding = None
 
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
         if time_embedding_act_fn is None:
             self.time_embed_act = None
         elif time_embedding_act_fn == "swish":
@@ -437,7 +457,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -648,7 +679,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +693,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
@@ -669,6 +704,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             else:
                 emb = emb + class_emb
 
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
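
Taken together, the new `addition_embed_type`, `addition_embed_type_num_heads`, and `time_embedding_dim` arguments let a UNet pool text-encoder states into its time embedding, IF-style. A toy configuration sketch, assuming diffusers >= 0.16; the sizes are deliberately small and do not correspond to any released checkpoint:

```python
# Toy sketch exercising the new constructor arguments; sizes are illustrative only.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    addition_embed_type="text",        # pool encoder states into the time embedding
    addition_embed_type_num_heads=4,
    time_embedding_dim=64,             # overrides the default block_out_channels[0] * 4
)

sample = torch.randn(1, 3, 8, 8)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 3, 8, 8])
```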
diffusers/models/vae.py CHANGED
@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,6 +223,7 @@ class Decoder(nn.Module):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -44,6 +44,14 @@ except OptionalDependencyNotAvailable:
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
+    from .deepfloyd_if import (
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
+    )
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
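
The new pipelines implement the two-stage DeepFloyd IF workflow: stage I generates a small image from T5 prompt embeddings, and a super-resolution stage upscales it. A sketch along the lines of the upstream usage example; exact generation arguments may vary, and the checkpoints are gated on the Hub behind the DeepFloyd license:

```python
# Sketch based on the upstream DeepFloyd IF example (diffusers >= 0.16).
import torch
from diffusers import IFPipeline, IFSuperResolutionPipeline

stage_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

prompt_embeds, negative_embeds = stage_1.encode_prompt("a photo of a red panda reading a book")
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

stage_2 = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()
upscaled = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pil"
).images[0]
upscaled.save("red_panda.png")
```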
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
 
 
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()
 
     def forward(
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )
 
-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -293,7 +293,7 @@ class AudioLDMPipeline(DiffusionPipeline):
 
         waveform = self.vocoder(mel_spectrogram)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        waveform = waveform.cpu()
+        waveform = waveform.cpu().float()
         return waveform
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs