diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
@@ -149,6 +149,9 @@ class Attention(nn.Module):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
         )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+        )
 
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
@@ -192,6 +195,17 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -203,6 +217,16 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = AttnProcessor()
 
@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class AttnAddedKVProcessor:
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=False,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
 ]
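The two new processors mirror the existing LoRA pair: a plain-PyTorch variant and an xformers variant that `set_use_memory_efficient_attention_xformers` can swap in while preserving the trained weights. A rough usage sketch, not part of the diff, for installing them on a UNet; the checkpoint id and the hidden-size lookup below are assumptions:

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # assumed checkpoint
)

custom_procs = {}
for name in unet.attn_processors.keys():
    # "attn1" layers are self-attention, "attn2" layers are cross-attention
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    custom_procs[name] = CustomDiffusionAttnProcessor(
        train_kv=True,
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )

unet.set_attn_processor(custom_procs)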
diffusers/models/controlnet.py
CHANGED
@@ -457,6 +457,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -557,8 +558,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale
 
         if self.config.global_pool_conditions:
             down_block_res_samples = [
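For orientation, the guess-mode branch scales the residuals geometrically rather than uniformly. A small illustration of the numbers involved (a typical Stable Diffusion UNet yields 12 down-block residuals, so 13 scales are drawn; the values here are just an example):

import torch

scales = torch.logspace(-1, 0, 13)  # 0.1 ... 1.0, geometrically spaced
scales *= 1.0                       # conditioning_scale
print(round(scales[0].item(), 3), round(scales[-1].item(), 3))  # 0.1 1.0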
diffusers/models/embeddings.py
CHANGED
@@ -377,3 +377,69 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         conditioning = timesteps_emb + class_labels  # (N, D)
 
         return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(encoder_dim)
+        self.pool = AttentionPooling(num_heads, encoder_dim)
+        self.proj = nn.Linear(encoder_dim, time_embed_dim)
+        self.norm2 = nn.LayerNorm(time_embed_dim)
+
+    def forward(self, hidden_states):
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.pool(hidden_states)
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        return hidden_states
+
+
+class AttentionPooling(nn.Module):
+    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+    def __init__(self, num_heads, embed_dim, dtype=None):
+        super().__init__()
+        self.dtype = dtype
+        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.num_heads = num_heads
+        self.dim_per_head = embed_dim // self.num_heads
+
+    def forward(self, x):
+        bs, length, width = x.size()
+
+        def shape(x):
+            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+            x = x.transpose(1, 2)
+            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+            x = x.transpose(1, 2)
+            return x
+
+        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)
+
+        # (bs*n_heads, class_token_length, dim_per_head)
+        q = shape(self.q_proj(class_token))
+        # (bs*n_heads, length+class_token_length, dim_per_head)
+        k = shape(self.k_proj(x))
+        v = shape(self.v_proj(x))
+
+        # (bs*n_heads, class_token_length, length+class_token_length):
+        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+        # (bs*n_heads, dim_per_head, class_token_length)
+        a = torch.einsum("bts,bcs->bct", weight, v)
+
+        # (bs, length+1, width)
+        a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+        return a[:, 0, :]  # cls_token
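A quick shape check of the new embedding (dimensions picked arbitrarily for illustration): a (batch, seq_len, encoder_dim) text-encoder output is attention-pooled and projected to (batch, time_embed_dim) so it can be summed with the timestep embedding.

import torch
from diffusers.models.embeddings import TextTimeEmbedding

emb = TextTimeEmbedding(encoder_dim=768, time_embed_dim=1024, num_heads=8)
text_states = torch.randn(2, 77, 768)  # e.g. a batch of text-encoder hidden states
print(emb(text_states).shape)          # torch.Size([2, 1024])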
diffusers/models/modeling_pytorch_flax_utils.py
CHANGED
@@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
                     .replace("_1", ".1")
                     .replace("_2", ".2")
                     .replace("_3", ".3")
+                    .replace("_4", ".4")
+                    .replace("_5", ".5")
+                    .replace("_6", ".6")
+                    .replace("_7", ".7")
+                    .replace("_8", ".8")
+                    .replace("_9", ".9")
                 )
 
         flax_key = ".".join(flax_key_tuple_array)
diffusers/models/modeling_utils.py
CHANGED
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import inspect
+import itertools
 import os
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -60,7 +61,8 @@ if is_safetensors_available():
 
 def get_parameter_device(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).device
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).device
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
@@ -75,7 +77,8 @@ def get_parameter_device(parameter: torch.nn.Module):
 
 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).dtype
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).dtype
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
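The motivation for chaining in `buffers()`: a module that registers only buffers and no trainable parameters would otherwise hit `StopIteration` immediately. A toy illustration (the module below is made up):

import itertools
import torch

class BufferOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("table", torch.zeros(4))

m = BufferOnly()
# parameters() alone is empty; chaining in buffers() still yields a tensor to inspect
print(next(itertools.chain(m.parameters(), m.buffers())).device)  # cpu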
diffusers/models/transformer_2d.py
CHANGED
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
-            encoder_hidden_states ( `torch.
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
diffusers/models/unet_2d_condition.py
CHANGED
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
         time_embedding_act_fn (`str`, *optional*, default to `None`):
             Optional activation function to use on the time embeddings only one time before they as passed to the rest
             of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -155,12 +160,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -170,6 +177,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()
 
@@ -214,7 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # time
         if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
                 raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
             self.time_proj = GaussianFourierProjection(
@@ -222,7 +230,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             )
             timestep_input_dim = time_embed_dim
         elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
 
             self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
             timestep_input_dim = block_out_channels[0]
@@ -248,7 +256,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -273,6 +281,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         else:
             self.class_embedding = None
 
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
         if time_embedding_act_fn is None:
             self.time_embed_act = None
         elif time_embedding_act_fn == "swish":
@@ -437,7 +457,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -648,7 +679,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +693,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
@@ -669,6 +704,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             else:
                 emb = emb + class_emb
 
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
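Taken together, the new arguments can be exercised with a toy configuration like the one below. This is only a sketch of how `addition_embed_type` and `time_embedding_dim` plug in; all sizes are arbitrary and far smaller than any released checkpoint.

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
    addition_embed_type="text",       # enables the TextTimeEmbedding branch above
    addition_embed_type_num_heads=2,
    time_embedding_dim=128,           # overrides block_out_channels[0] * 4
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 64)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])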
diffusers/models/vae.py
CHANGED
@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,6 +223,7 @@ class Decoder(nn.Module):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
            sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
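The decoder now looks up the dtype of the first up-block parameter and casts the mid-block output to it, which is what allows the up blocks to run in a different precision than the earlier layers. The lookup itself is just this pattern (toy module for illustration):

import torch

block = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3)).to(torch.float32)
upscale_dtype = next(iter(block.parameters())).dtype
print(upscale_dtype)  # torch.float32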
diffusers/pipelines/__init__.py
CHANGED
@@ -44,6 +44,14 @@ except OptionalDependencyNotAvailable:
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
+    from .deepfloyd_if import (
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
+    )
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
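A rough end-to-end sketch of the newly exported DeepFloyd IF pipelines; the checkpoint names below are assumptions, and the weights are gated behind a license acceptance on the Hub:

import torch
from diffusers import IFPipeline, IFSuperResolutionPipeline

# stage 1: 64x64 text-to-image
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a corgi wearing a spacesuit")
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

# stage 2: super-resolution, reusing the same prompt embeddings
super_res = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, torch_dtype=torch.float16
)
super_res.enable_model_cpu_offload()
image = super_res(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pil"
).images[0]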
diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
CHANGED
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
 
 
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()
 
     def forward(
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )
 
-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
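The new `has_pre_transformation` path is driven purely by the config object; a hypothetical tiny configuration shows how it would be switched on (all sizes are arbitrary):

from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)

config = RobertaSeriesConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    project_dim=16,
    has_pre_transformation=True,  # routes the penultimate hidden state through pre_LN + transformation_pre
)
model = RobertaSeriesModelWithTransformation(config)
print(hasattr(model, "transformation_pre"), hasattr(model, "pre_LN"))  # True True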
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
CHANGED
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
CHANGED
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
diffusers/pipelines/audioldm/pipeline_audioldm.py
CHANGED
@@ -293,7 +293,7 @@ class AudioLDMPipeline(DiffusionPipeline):
 
         waveform = self.vocoder(mel_spectrogram)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        waveform = waveform.cpu()
+        waveform = waveform.cpu().float()
         return waveform
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|