diffsynth-engine 0.4.3.dev9__py3-none-any.whl → 0.4.3.dev10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +2 -1
- diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +29 -0
- diffsynth_engine/models/basic/attention.py +3 -3
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +41 -57
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +45 -28
- diffsynth_engine/pipelines/base.py +1 -1
- diffsynth_engine/pipelines/qwen_image.py +125 -13
- diffsynth_engine/pipelines/sd_image.py +3 -3
- diffsynth_engine/pipelines/sdxl_image.py +10 -6
- diffsynth_engine/tokenizers/__init__.py +4 -0
- diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +157 -0
- diffsynth_engine/tokenizers/qwen2_vl_processor.py +100 -0
- diffsynth_engine/utils/constants.py +6 -0
- diffsynth_engine/utils/image.py +213 -0
- diffsynth_engine/utils/offload.py +6 -5
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev10.dist-info}/METADATA +2 -2
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev10.dist-info}/RECORD +20 -17
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev10.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev10.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev10.dist-info}/top_level.txt +0 -0
diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json
@@ -0,0 +1,29 @@
+{
+  "do_convert_rgb": true,
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "image_mean": [
+    0.48145466,
+    0.4578275,
+    0.40821073
+  ],
+  "image_processor_type": "Qwen2VLImageProcessor",
+  "image_std": [
+    0.26862954,
+    0.26130258,
+    0.27577711
+  ],
+  "max_pixels": 12845056,
+  "merge_size": 2,
+  "min_pixels": 3136,
+  "patch_size": 14,
+  "processor_class": "Qwen2_5_VLProcessor",
+  "resample": 3,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "longest_edge": 12845056,
+    "shortest_edge": 3136
+  },
+  "temporal_patch_size": 2
+}
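Note: this config mirrors the reference Qwen2-VL preprocessing: min_pixels = 3136 (56x56) and max_pixels = 12845056 bound the resized area, and both edges are snapped to multiples of patch_size * merge_size = 28. A minimal sketch of that resize rule (function name and rounding details are illustrative, not necessarily the package's exact helper):

    import math

    def smart_resize(height, width, factor=28, min_pixels=3136, max_pixels=12845056):
        # Snap both edges to a multiple of patch_size * merge_size (14 * 2 = 28),
        # then rescale so the total pixel count stays inside [min_pixels, max_pixels].
        h = round(height / factor) * factor
        w = round(width / factor) * factor
        if h * w > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h = math.floor(height / beta / factor) * factor
            w = math.floor(width / beta / factor) * factor
        elif h * w < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h = math.ceil(height * beta / factor) * factor
            w = math.ceil(width * beta / factor) * factor
        return h, w

    smart_resize(1080, 1920)  # -> (1092, 1932), area within [3136, 12845056]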
diffsynth_engine/models/basic/attention.py
@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 from typing import Optional
 
-import torch.nn.functional as F
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.flag import (
     FLASH_ATTN_3_AVAILABLE,
@@ -42,11 +42,11 @@ if XFORMERS_AVAILABLE:
 
 if SDPA_AVAILABLE:
 
-    def sdpa_attn(q, k, v, attn_mask=None, scale=None):
+    def sdpa_attn(q, k, v, attn_mask=None, is_causal=False, scale=None):
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
-        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale)
         return out.transpose(1, 2)
 
 
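Note: F.scaled_dot_product_attention builds the causal mask internally when is_causal=True, so callers no longer need to materialize an explicit attn_mask. A small usage sketch of the updated wrapper (inputs follow the (batch, seq, heads, head_dim) convention the transposes imply):

    import torch

    q = k = v = torch.randn(1, 16, 8, 64)      # (batch, seq, heads, head_dim)
    out = sdpa_attn(q, k, v, is_causal=True)   # causal mask applied inside SDPA
    assert out.shape == (1, 16, 8, 64)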
diffsynth_engine/models/qwen_image/qwen2_5_vl.py
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Tuple, Optional
 
 from diffsynth_engine.models.base import PreTrainedModel
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.cache import Cache, DynamicCache
 from diffsynth_engine.utils import logging
@@ -152,17 +152,15 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
         self,
         dim: int = 80,
         theta: float = 10000.0,
-        device: str = "cuda:0",
-        dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        with torch.device(device):
-            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
+        with torch.device("cpu"):
+            self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
 
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(seq, self.inv_freq)
+    def forward(self, seqlen: int, device: str) -> torch.Tensor:
+        inv_freq = self.inv_freq.to(device=device)
+        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(seq, inv_freq)
         return freqs
 
 
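Note: after this refactor inv_freq lives on CPU as a plain attribute rather than a registered buffer, so meta-device initialization and offloading never have to materialize it; each forward copies it to the caller's device on demand. An illustrative call (shapes follow from dim=80, i.e. dim // 2 = 40 frequencies):

    emb = Qwen2_5_VisionRotaryEmbedding(dim=80)
    freqs = emb(seqlen=32, device="cuda:0")  # inv_freq moved to CUDA on demand
    assert freqs.shape == (32, 40)           # one row per position, dim // 2 frequencies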
@@ -222,7 +220,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         q = rearrange(q, "s n d -> 1 s n d")
         k = rearrange(k, "s n d -> 1 s n d")
         v = rearrange(v, "s n d -> 1 s n d")
-        out = attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
+        out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
         out = rearrange(out, "1 s n d -> s (n d)")
         out = self.proj(out)
         return out
@@ -301,7 +299,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             dtype=dtype,
         )
         head_dim = config.hidden_size // config.num_heads
-        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device, dtype=dtype)
+        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
         self.blocks = nn.ModuleList(
             [
                 Qwen2_5_VisionBlock(
@@ -348,7 +346,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
@@ -488,7 +486,6 @@ class Qwen2_5_Attention(nn.Module):
         hidden_size: int = 3584,
         num_attention_heads: int = 28,
         num_key_value_heads: int = 4,
-        # dropout: float = 0.0,
         mrope_section: List[int] = [16, 24, 24],
         attn_impl: Optional[str] = None,
         device: str = "cuda:0",
@@ -501,7 +498,6 @@ class Qwen2_5_Attention(nn.Module):
         self.head_dim = hidden_size // num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = num_attention_heads // num_key_value_heads
-        # self.dropout = dropout
         self.mrope_section = mrope_section
         self.attn_impl = attn_impl
 
@@ -521,8 +517,6 @@ class Qwen2_5_Attention(nn.Module):
             self.num_attention_heads * self.head_dim, self.hidden_size, bias=False, device=device, dtype=dtype
         )
 
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=self.head_dim, device=device, dtype=dtype)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -556,14 +550,18 @@
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[1]]
 
-        # TODO:
-        out = attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_impl=self.attn_impl,
-            attn_mask=causal_mask,
-        )
+        # TODO: use is_causal when attention mask is causal
+        if self.attn_impl == "sdpa":
+            out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True)
+        else:
+            # TODO: attention_mask for flash attention 2
+            out = attention_ops.attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_impl=self.attn_impl,
+                attn_mask=causal_mask,
+            )
         out = rearrange(out, "b s n d -> b s (n d)")
         out = self.o_proj(out)
         return out, past_key_values
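Note: for a standard lower-triangular mask, is_causal=True is equivalent to passing the mask explicitly, just cheaper because nothing is materialized. A quick check of that assumption with plain PyTorch:

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 8, 16, 64)               # (batch, heads, seq, head_dim)
    mask = torch.ones(16, 16, dtype=torch.bool).tril()  # explicit causal mask (True = attend)
    a = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    b = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(a, b, atol=1e-5)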
@@ -647,29 +645,29 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
 
 
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int = 128, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+    def __init__(self, dim: int = 128):
         super().__init__()
-        with torch.device(device):
-            inv_freq = self.compute_rope(dim)  # default rope without dynamic frequency
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
+        with torch.device("cpu"):
+            self.inv_freq = self.compute_rope(dim)  # default rope without dynamic frequency
 
     def compute_rope(self, dim: int, theta: float = 1000000.0):
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
         return inv_freq
 
     @torch.no_grad()
-    def forward(self, x: torch.Tensor, position_ids: torch.LongTensor):
+    def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
         # In contrast to other models, Qwen2_5_VL has different position ids for the grids
         # So we expand the inv_freq to shape (3, ...)
-        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+        inv_freq = self.inv_freq.to(device=device)
+        inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
 
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos()
         sin = emb.sin()
 
-        return cos.to(device=x.device, dtype=x.dtype), sin.to(device=x.device, dtype=x.dtype)
+        return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
 
 
 class Qwen2_5_VLModel(nn.Module):
@@ -702,7 +700,7 @@ class Qwen2_5_VLModel(nn.Module):
         )
         self.norm = Qwen2_5_RMSNorm(config.hidden_size, config.rms_norm_eps, device=device, dtype=dtype)
         head_dim = config.hidden_size // config.num_attention_heads
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim, device=device, dtype=dtype)
+        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim)
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -749,7 +747,7 @@ class Qwen2_5_VLModel(nn.Module):
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = self.rotary_emb(position_ids, device=hidden_states.device, dtype=hidden_states.dtype)
 
         # decoder layers
         for decoder_layer in self.layers:
@@ -940,8 +938,7 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
         with torch.device("meta"), no_init_weights():
             model = cls(vision_config=vision_config, config=config, device=device, dtype=dtype)
         model.load_state_dict(state_dict, assign=True)
-        for param in model.parameters():
-            param.data = param.data.to(device=device, dtype=dtype, non_blocking=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
     def get_input_embeddings(self):
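Note: the replaced per-parameter loop and the new model.to(...) realize the same meta-device loading idiom: load_state_dict(..., assign=True) swaps the meta tensors for the real ones, after which a single .to() handles device and dtype. A minimal sketch with a toy module (standard PyTorch, not the package's code):

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        model = nn.Linear(1024, 1024)               # skeleton only, no memory allocated

    state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.zeros(1024)}
    model.load_state_dict(state_dict, assign=True)  # attach real tensors in place of meta ones
    model.to(device="cpu", dtype=torch.bfloat16, non_blocking=True)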
@@ -1202,27 +1199,14 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
         if position_ids is None:
             assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
             # calculate RoPE index once per generation in the pre-fill stage only
-            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids,
-                    image_grid_thw,
-                    video_grid_thw,
-                    second_per_grid_ts,
-                    attention_mask,
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-                )
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+            position_ids, rope_deltas = self.get_rope_index(
+                input_ids,
+                image_grid_thw,
+                video_grid_thw,
+                second_per_grid_ts,
+                attention_mask,
+            )
+            self.rope_deltas = rope_deltas
 
         hidden_states, present_key_values = self.model(
             input_ids=None,
diffsynth_engine/models/qwen_image/qwen_image_dit.py
@@ -81,41 +81,47 @@ class QwenEmbedRope(nn.Module):
 
     def forward(self, video_fhw, txt_length, device):
         """
-        Args:
-            video_fhw (Tuple[int, int, int]): The (frame, height, width) shape of the video/image
+        Args:
+            video_fhw (List[Tuple[int, int, int]]): A list of (frame, height, width) tuples for each video/image
+            txt_length (int): The maximum length of the text sequences
         """
         if self.pos_freqs.device != device:
             self.pos_freqs = self.pos_freqs.to(device)
             self.neg_freqs = self.neg_freqs.to(device)
 
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
-
-        if rope_key not in self.rope_cache:
-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if rope_key not in self.rope_cache:
+                seq_lens = frame * height * width
+                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+                if self.scale_rope:
+                    freqs_height = torch.cat(
+                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
+                    )
+                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+
+                else:
+                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+                self.rope_cache[rope_key] = freqs.clone().contiguous()
+            vid_freqs.append(self.rope_cache[rope_key])
         if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
+            max_vid_index = max(height // 2, width // 2, max_vid_index)
         else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
-
-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
+            max_vid_index = max(height, width, max_vid_index)
 
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
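Note: QwenEmbedRope.forward now takes a list of (frame, height, width) tuples so edit conditioning can reuse the positional machinery: each entry's frequencies are memoized in rope_cache under an "{idx}_{height}_{width}" key and concatenated at the end. A hedged usage sketch (constructor args elided; shapes assume two 32x32-patch latents):

    rope = QwenEmbedRope(...)  # hypothetical construction, real args elided
    vid_freqs, txt_freqs = rope([(1, 32, 32), (1, 32, 32)], txt_length=77, device="cuda:0")
    # vid_freqs has 1*32*32 + 1*32*32 = 2048 rows; txt_freqs has 77 rows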
@@ -364,6 +370,7 @@ class QwenImageDiT(PreTrainedModel):
     def forward(
         self,
         image: torch.Tensor,
+        edit: torch.Tensor = None,
         text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         txt_seq_lens: torch.LongTensor = None,
@@ -377,6 +384,7 @@ class QwenImageDiT(PreTrainedModel):
         cfg_parallel(
             (
                 image,
+                edit,
                 text,
                 timestep,
                 txt_seq_lens,
@@ -385,11 +393,18 @@ class QwenImageDiT(PreTrainedModel):
             ),
         ):
             conditioning = self.time_text_embed(timestep, image.dtype)
-            video_fhw = (1, h // 2, w // 2)  # frame, height, width
+            video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
             max_length = txt_seq_lens.max().item()
+            image = self.patchify(image)
+            image_seq_len = image.shape[1]
+            if edit is not None:
+                edit = edit.to(dtype=image.dtype)
+                edit = self.patchify(edit)
+                image = torch.cat([image, edit], dim=1)
+                video_fhw += video_fhw
+
             image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device)
 
-            image = self.patchify(image)
             image = self.img_in(image)
             text = self.txt_in(self.txt_norm(text[:, :max_length]))
 
@@ -397,6 +412,8 @@ class QwenImageDiT(PreTrainedModel):
                 text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
             image = self.norm_out(image, conditioning)
             image = self.proj_out(image)
+            if edit is not None:
+                image = image[:, :image_seq_len]
 
             image = self.unpatchify(image, h, w)
 
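Note: taken together, the DiT hunks implement edit conditioning by sequence concatenation: the edit latent is patchified and appended after the image tokens, video_fhw is doubled so each sub-sequence gets its own RoPE block, and only the first image_seq_len output tokens are unpatchified. A minimal shape sketch (dimensions illustrative):

    import torch

    B, N, D = 1, 1024, 64                    # batch, tokens per latent, channels
    image, edit = torch.randn(B, N, D), torch.randn(B, N, D)
    joint = torch.cat([image, edit], dim=1)  # (B, 2N, D): denoised + reference tokens
    out = joint                              # stand-in for the transformer stack
    out = out[:, :N]                         # keep only the denoised-image tokens
    assert out.shape == (B, N, D)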
diffsynth_engine/pipelines/base.py
@@ -164,7 +164,7 @@ class BasePipeline:
     @staticmethod
     def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
         generator = None if seed is None else torch.Generator(device).manual_seed(seed)
-        noise = torch.randn(shape, generator=generator, device=device)
+        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
         return noise
 
     def encode_image(
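Note: the fix threads dtype through to torch.randn, so noise is sampled directly in the requested precision instead of the global default (float32). With a fixed seed the draw is reproducible for a given device and dtype:

    n1 = BasePipeline.generate_noise((1, 4, 64, 64), seed=42, device="cpu", dtype=torch.bfloat16)
    n2 = BasePipeline.generate_noise((1, 4, 64, 64), seed=42, device="cpu", dtype=torch.bfloat16)
    assert torch.equal(n1, n2) and n1.dtype == torch.bfloat16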