diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0

diffusers/models/attention_processor.py CHANGED
@@ -149,6 +149,9 @@ class Attention(nn.Module):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
         )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+        )
 
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
@@ -192,6 +195,17 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -203,6 +217,16 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = AttnProcessor()
 
@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class AttnAddedKVProcessor:
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=False,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
 ]
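
Note: the two new `CustomDiffusion*AttnProcessor` classes are selected per attention layer via `unet.set_attn_processor`. Below is a minimal sketch of how they might be wired up; the checkpoint id and the layer-name heuristics are illustrative assumptions (they mirror the usual `attn_processors` naming in `diffusers`), not something prescribed by this diff.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

# Assumed checkpoint, used here only for illustration.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

custom_procs = {}
for name in unet.attn_processors.keys():
    # `attn1` layers are self-attention; `attn2` layers are cross-attention and get trainable k/v.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    custom_procs[name] = CustomDiffusionAttnProcessor(
        train_kv=cross_attention_dim is not None,  # only train k/v on cross-attention layers
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )

unet.set_attn_processor(custom_procs)
```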

diffusers/models/controlnet.py CHANGED
@@ -457,6 +457,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -557,8 +558,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
 
         if self.config.global_pool_conditions:
             down_block_res_samples = [
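
Note: in guess mode the per-residual weights form a log-spaced ramp, so shallow down-block residuals are damped toward 0.1 while the mid-block residual keeps the full `conditioning_scale`. A quick standalone check (the residual count of 12 is an assumption matching the standard Stable Diffusion UNet):

```python
import torch

num_down_block_res_samples = 12  # assumed; matches the standard SD UNet used with ControlNet
scales = torch.logspace(-1, 0, num_down_block_res_samples + 1)  # 13 log-spaced weights
print(scales[0].item(), scales[-1].item())  # ~0.1 for the shallowest residual, 1.0 for the mid block
```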

diffusers/models/embeddings.py CHANGED
@@ -377,3 +377,69 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         conditioning = timesteps_emb + class_labels  # (N, D)
 
         return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(encoder_dim)
+        self.pool = AttentionPooling(num_heads, encoder_dim)
+        self.proj = nn.Linear(encoder_dim, time_embed_dim)
+        self.norm2 = nn.LayerNorm(time_embed_dim)
+
+    def forward(self, hidden_states):
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.pool(hidden_states)
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        return hidden_states
+
+
+class AttentionPooling(nn.Module):
+    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+    def __init__(self, num_heads, embed_dim, dtype=None):
+        super().__init__()
+        self.dtype = dtype
+        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.num_heads = num_heads
+        self.dim_per_head = embed_dim // self.num_heads
+
+    def forward(self, x):
+        bs, length, width = x.size()
+
+        def shape(x):
+            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+            x = x.transpose(1, 2)
+            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+            x = x.transpose(1, 2)
+            return x
+
+        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)
+
+        # (bs*n_heads, class_token_length, dim_per_head)
+        q = shape(self.q_proj(class_token))
+        # (bs*n_heads, length+class_token_length, dim_per_head)
+        k = shape(self.k_proj(x))
+        v = shape(self.v_proj(x))
+
+        # (bs*n_heads, class_token_length, length+class_token_length):
+        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+        # (bs*n_heads, dim_per_head, class_token_length)
+        a = torch.einsum("bts,bcs->bct", weight, v)
+
+        # (bs, length+1, width)
+        a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+        return a[:, 0, :]  # cls_token
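
Note: `TextTimeEmbedding` pools the per-token text states into a single vector and projects it to the time-embedding width, so it can be summed with `emb` in the UNet (see the `addition_embed_type == "text"` branch further down). A quick shape check; the 768/1280 widths are chosen purely for illustration:

```python
import torch
from diffusers.models.embeddings import TextTimeEmbedding

encoder_dim, time_embed_dim = 768, 1280  # assumed text-encoder and time-embedding widths
text_time = TextTimeEmbedding(encoder_dim, time_embed_dim, num_heads=64)

encoder_hidden_states = torch.randn(2, 77, encoder_dim)  # (batch, seq_len, encoder_dim)
aug_emb = text_time(encoder_hidden_states)
print(aug_emb.shape)  # torch.Size([2, 1280]) -- one pooled vector per batch element
```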

diffusers/models/modeling_pytorch_flax_utils.py CHANGED
@@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
                 .replace("_1", ".1")
                 .replace("_2", ".2")
                 .replace("_3", ".3")
+                .replace("_4", ".4")
+                .replace("_5", ".5")
+                .replace("_6", ".6")
+                .replace("_7", ".7")
+                .replace("_8", ".8")
+                .replace("_9", ".9")
             )
 
         flax_key = ".".join(flax_key_tuple_array)

diffusers/models/modeling_utils.py CHANGED
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import inspect
+import itertools
 import os
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -60,7 +61,8 @@ if is_safetensors_available():
 
 def get_parameter_device(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).device
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).device
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
@@ -75,7 +77,8 @@ def get_parameter_device(parameter: torch.nn.Module):
 
 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).dtype
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).dtype
    except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
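Note: chaining `buffers()` matters for modules that hold no trainable parameters but do register buffers; previously such modules fell through to the `StopIteration` fallback. Illustrative snippet:

```python
import itertools
import torch

class BufferOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # no parameters, only a registered buffer
        self.register_buffer("table", torch.zeros(4, dtype=torch.float16))

m = BufferOnly()
params_and_buffers = itertools.chain(m.parameters(), m.buffers())
print(next(params_and_buffers).dtype)  # torch.float16 -- resolvable now that buffers are included
```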

diffusers/models/transformer_2d.py CHANGED
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
-            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):

diffusers/models/unet_2d_condition.py CHANGED
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
         time_embedding_act_fn (`str`, *optional*, default to `None`):
             Optional activation function to use on the time embeddings only one time before they as passed to the rest
             of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -155,12 +160,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -170,6 +177,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()
 
@@ -214,7 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # time
         if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
                 raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
             self.time_proj = GaussianFourierProjection(
@@ -222,7 +230,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             )
             timestep_input_dim = time_embed_dim
         elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
 
             self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
             timestep_input_dim = block_out_channels[0]
@@ -248,7 +256,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -273,6 +281,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         else:
             self.class_embedding = None
 
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
         if time_embedding_act_fn is None:
             self.time_embed_act = None
         elif time_embedding_act_fn == "swish":
@@ -437,7 +457,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
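
Note: the new `__init__` knobs compose as follows: `time_embedding_dim` overrides the projected time-embedding width (instead of `block_out_channels[0] * 4`), and `addition_embed_type="text"` adds a pooled-text `TextTimeEmbedding` that is summed with the time embedding in the forward pass (see the hunks below). A small self-contained sketch with toy sizes; all dimensions here are illustrative, not defaults:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    time_embedding_dim=256,           # new: overrides block_out_channels[0] * 4
    addition_embed_type="text",       # new: pooled text embedding summed with the time embedding
    addition_embed_type_num_heads=8,  # heads for the AttentionPooling inside TextTimeEmbedding
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```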
@@ -648,7 +679,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +693,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
@@ -669,6 +704,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             else:
                 emb = emb + class_emb
 
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
diffusers/models/vae.py CHANGED
@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,6 +223,7 @@ class Decoder(nn.Module):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:

diffusers/pipelines/__init__.py CHANGED
@@ -44,6 +44,14 @@ except OptionalDependencyNotAvailable:
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
+    from .deepfloyd_if import (
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
+    )
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
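
Note: the DeepFloyd IF pipelines are now importable from the top-level package. A hedged usage sketch of the cascaded text-to-image path follows; the model ids, the gated-checkpoint access, and the fp16 settings are assumptions based on the public DeepFloyd IF release, not something this diff sets up:

```python
import torch
from diffusers import IFPipeline, IFSuperResolutionPipeline

# Stage 1: 64x64 base model (gated checkpoint; requires accepting the DeepFloyd license on the Hub).
stage_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

prompt = "a photo of a corgi wearing a party hat"
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
image = stage_1(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
).images

# Stage 2: 64 -> 256 super-resolution, reusing the stage-1 text embeddings.
stage_2 = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()
upscaled = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
).images[0]
```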

diffusers/pipelines/alt_diffusion/modeling_roberta_series.py CHANGED
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
 
 
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()
 
     def forward(
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )
 
-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )

diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py CHANGED
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py CHANGED
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

diffusers/pipelines/audioldm/pipeline_audioldm.py CHANGED
@@ -293,7 +293,7 @@ class AudioLDMPipeline(DiffusionPipeline):
 
         waveform = self.vocoder(mel_spectrogram)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        waveform = waveform.cpu()
+        waveform = waveform.cpu().float()
         return waveform
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
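
Note on the extra `.float()`: bfloat16 tensors cannot be converted to NumPy directly, so upcasting the vocoder output on CPU keeps any downstream `.numpy()` call working. Illustrative snippet:

```python
import torch

wav = torch.randn(1, 16000, dtype=torch.bfloat16)  # stand-in for the vocoder output
# wav.cpu().numpy()  # would raise: TypeError: Got unsupported ScalarType BFloat16
wav = wav.cpu().float().numpy()
print(wav.dtype)  # float32
```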