diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py
CHANGED
@@ -149,6 +149,9 @@ class Attention(nn.Module):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
         )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+        )

         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
@@ -192,6 +195,17 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -203,6 +217,16 @@ class Attention(nn.Module):
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = AttnProcessor()

@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
         return hidden_states


+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class AttnAddedKVProcessor:
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
         return hidden_states


+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=False,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
    AttnAddedKVProcessor2_0,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
 ]
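The two new processors mirror the existing LoRA processors: small `*_custom_diffusion` projections are trained while the rest of the attention module stays frozen. A minimal sketch of wiring them into a UNet follows; the checkpoint id, the per-layer hidden-size lookup, and the train_kv/train_q_out choices are illustrative assumptions, not part of this diff.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

# Example checkpoint (assumption): any Stable Diffusion UNet works the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

custom_procs = {}
for name in unet.attn_processors.keys():
    # attn2.* processors sit on cross-attention layers, attn1.* on self-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    custom_procs[name] = CustomDiffusionAttnProcessor(
        train_kv=cross_attention_dim is not None,  # only add trainable K/V projections on cross-attention
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )

unet.set_attn_processor(custom_procs)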
diffusers/models/controlnet.py
CHANGED
@@ -457,6 +457,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         # check channel order
@@ -557,8 +558,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         mid_block_res_sample = self.controlnet_mid_block(sample)

         # 6. scaling
-        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-        mid_block_res_sample *= conditioning_scale
+        if guess_mode:
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
+            scales *= conditioning_scale
+            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+            mid_block_res_sample *= scales[-1]  # last one
+        else:
+            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+            mid_block_res_sample *= conditioning_scale

         if self.config.global_pool_conditions:
             down_block_res_samples = [
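For reference, the `guess_mode` branch weights the ControlNet residuals geometrically instead of uniformly, so shallow blocks contribute least and the mid block most. A tiny illustration with an assumed residual count (not taken from the diff):

import torch

num_down_residuals = 12  # typical count for a Stable Diffusion UNet (assumption)
scales = torch.logspace(-1, 0, num_down_residuals + 1)  # 13 values ramping from 0.1 to 1.0
print(scales[0].item(), scales[-1].item())  # ~0.1  1.0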
diffusers/models/embeddings.py
CHANGED
@@ -377,3 +377,69 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         conditioning = timesteps_emb + class_labels  # (N, D)

         return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(encoder_dim)
+        self.pool = AttentionPooling(num_heads, encoder_dim)
+        self.proj = nn.Linear(encoder_dim, time_embed_dim)
+        self.norm2 = nn.LayerNorm(time_embed_dim)
+
+    def forward(self, hidden_states):
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.pool(hidden_states)
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        return hidden_states
+
+
+class AttentionPooling(nn.Module):
+    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+    def __init__(self, num_heads, embed_dim, dtype=None):
+        super().__init__()
+        self.dtype = dtype
+        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+        self.num_heads = num_heads
+        self.dim_per_head = embed_dim // self.num_heads
+
+    def forward(self, x):
+        bs, length, width = x.size()
+
+        def shape(x):
+            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+            x = x.transpose(1, 2)
+            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+            x = x.transpose(1, 2)
+            return x
+
+        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)
+
+        # (bs*n_heads, class_token_length, dim_per_head)
+        q = shape(self.q_proj(class_token))
+        # (bs*n_heads, length+class_token_length, dim_per_head)
+        k = shape(self.k_proj(x))
+        v = shape(self.v_proj(x))
+
+        # (bs*n_heads, class_token_length, length+class_token_length):
+        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+        # (bs*n_heads, dim_per_head, class_token_length)
+        a = torch.einsum("bts,bcs->bct", weight, v)
+
+        # (bs, length+1, width)
+        a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+        return a[:, 0, :]  # cls_token
diffusers/models/modeling_pytorch_flax_utils.py
CHANGED
@@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
                 .replace("_1", ".1")
                 .replace("_2", ".2")
                 .replace("_3", ".3")
+                .replace("_4", ".4")
+                .replace("_5", ".5")
+                .replace("_6", ".6")
+                .replace("_7", ".7")
+                .replace("_8", ".8")
+                .replace("_9", ".9")
             )

             flax_key = ".".join(flax_key_tuple_array)
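The extra replacements extend the Flax-to-PyTorch key renaming to list indices above 3, so deeper blocks map onto the PyTorch module tree. A toy illustration of the renaming idea (simplified and assumed; the real function operates on individual key-tuple segments):

segment = "resnets_4"  # hypothetical Flax key segment
for i in range(10):
    segment = segment.replace(f"_{i}", f".{i}")
print(segment)  # resnets.4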
diffusers/models/modeling_utils.py
CHANGED
@@ -15,6 +15,7 @@
 # limitations under the License.

 import inspect
+import itertools
 import os
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -60,7 +61,8 @@ if is_safetensors_available():

 def get_parameter_device(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).device
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).device
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5

@@ -75,7 +77,8 @@ def get_parameter_device(parameter: torch.nn.Module):

 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        return next(parameter.parameters()).dtype
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).dtype
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5

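Chaining buffers into the lookup matters for modules that register buffers but no parameters, where `next(module.parameters())` alone raises `StopIteration`. A minimal illustration (not from the diff):

import itertools

import torch
from torch import nn


class BufferOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))


module = BufferOnly()
# next(module.parameters()) would raise StopIteration: the module has no parameters.
print(next(itertools.chain(module.parameters(), module.buffers())).device)  # cpu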
diffusers/models/transformer_2d.py
CHANGED
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
-            encoder_hidden_states ( `torch.
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
diffusers/models/unet_2d_condition.py
CHANGED
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
         time_embedding_act_fn (`str`, *optional*, default to `None`):
             Optional activation function to use on the time embeddings only one time before they as passed to the rest
             of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -155,12 +160,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -170,6 +177,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()

@@ -214,7 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)

         # time
         if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
                 raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
             self.time_proj = GaussianFourierProjection(
@@ -222,7 +230,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             )
             timestep_input_dim = time_embed_dim
         elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

             self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
             timestep_input_dim = block_out_channels[0]
@@ -248,7 +256,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -273,6 +281,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         else:
             self.class_embedding = None

+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
         if time_embedding_act_fn is None:
             self.time_embed_act = None
         elif time_embedding_act_fn == "swish":
@@ -437,7 +457,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -648,7 +679,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)

         t_emb = self.time_proj(timesteps)

-        #
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +693,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)

+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

             if self.config.class_embeddings_concat:
@@ -669,6 +704,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             else:
                 emb = emb + class_emb

+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)

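A hypothetical sketch of the new constructor options (all sizes below are illustrative assumptions, not taken from the diff): `addition_embed_type="text"` pools `encoder_hidden_states` and adds the result to the time embedding, while `time_embedding_dim` overrides the projected width.

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
    addition_embed_type="text",       # adds a pooled text embedding on top of the time embedding
    addition_embed_type_num_heads=8,  # heads used by the AttentionPooling layer
    time_embedding_dim=128,           # overrides the default block_out_channels[0] * 4
)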
diffusers/models/vae.py
CHANGED
@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)

+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:

             def create_custom_forward(module):
@@ -222,6 +223,7 @@ class Decoder(nn.Module):

             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
diffusers/pipelines/__init__.py
CHANGED
@@ -44,6 +44,14 @@ except OptionalDependencyNotAvailable:
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
+    from .deepfloyd_if import (
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
+    )
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
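These exports make the new DeepFloyd IF pipelines importable from the top-level package. A minimal two-stage usage sketch; the model ids, fp16 variant, resolutions, and prompt are assumptions taken from the upstream model cards rather than this diff, and the gated weights require accepting the license on the Hub first.

import torch
from diffusers import IFPipeline, IFSuperResolutionPipeline

# Stage I: text -> 64x64 image, kept as a torch tensor for the next stage.
stage_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

prompt_embeds, negative_embeds = stage_1.encode_prompt("a photo of a corgi wearing a party hat")
image = stage_1(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
).images

# Stage II: super-resolution to 256x256, reusing the stage I text embeddings.
stage_2 = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()
image = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pil"
).images[0]
image.save("if_stage_2.png")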
diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
CHANGED
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):


 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()

     def forward(
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )

-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
CHANGED
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
CHANGED
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
diffusers/pipelines/audioldm/pipeline_audioldm.py
CHANGED
@@ -293,7 +293,7 @@ class AudioLDMPipeline(DiffusionPipeline):

         waveform = self.vocoder(mel_spectrogram)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        waveform = waveform.cpu()
+        waveform = waveform.cpu().float()
         return waveform

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
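The added `.float()` upcast matters when the vocoder runs in bfloat16, because bfloat16 tensors cannot be converted to NumPy for saving. A minimal illustration (not from the diff; the waveform here is random data):

import torch

waveform = torch.randn(1, 16000, dtype=torch.bfloat16)
# waveform.cpu().numpy() would raise: "Got unsupported ScalarType BFloat16"
audio = waveform.cpu().float().numpy()
print(audio.dtype)  # float32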
|