diffsynth-engine 0.6.1.dev22__py3-none-any.whl → 0.6.1.dev23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/pipeline.py +33 -5
- diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine/models/basic/video_sparse_attention.py +235 -0
- diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- diffsynth_engine/models/flux/flux_dit.py +22 -36
- diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +13 -15
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/wan/wan_dit.py +62 -22
- diffsynth_engine/pipelines/flux_image.py +11 -10
- diffsynth_engine/pipelines/qwen_image.py +3 -10
- diffsynth_engine/pipelines/wan_s2v.py +3 -8
- diffsynth_engine/pipelines/wan_video.py +11 -13
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/parallel.py +51 -6
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/RECORD +34 -32
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev22.dist-info → diffsynth_engine-0.6.1.dev23.dist-info}/top_level.txt +0 -0

diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json
@@ -0,0 +1,41 @@
+{
+    "diffusers": {
+        "global_rename_dict": {
+            "patch_embedding": "patch_embedding",
+            "condition_embedder.text_embedder.linear_1": "text_embedding.0",
+            "condition_embedder.text_embedder.linear_2": "text_embedding.2",
+            "condition_embedder.time_embedder.linear_1": "time_embedding.0",
+            "condition_embedder.time_embedder.linear_2": "time_embedding.2",
+            "condition_embedder.time_proj": "time_projection.1",
+            "condition_embedder.image_embedder.norm1": "img_emb.proj.0",
+            "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
+            "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
+            "condition_embedder.image_embedder.norm2": "img_emb.proj.4",
+            "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
+            "proj_out": "head.head",
+            "scale_shift_table": "head.modulation"
+        },
+        "rename_dict": {
+            "attn1.to_q": "self_attn.q",
+            "attn1.to_k": "self_attn.k",
+            "attn1.to_v": "self_attn.v",
+            "attn1.to_out.0": "self_attn.o",
+            "attn1.norm_q": "self_attn.norm_q",
+            "attn1.norm_k": "self_attn.norm_k",
+            "to_gate_compress": "self_attn.gate_compress",
+            "attn2.to_q": "cross_attn.q",
+            "attn2.to_k": "cross_attn.k",
+            "attn2.to_v": "cross_attn.v",
+            "attn2.to_out.0": "cross_attn.o",
+            "attn2.norm_q": "cross_attn.norm_q",
+            "attn2.norm_k": "cross_attn.norm_k",
+            "attn2.add_k_proj": "cross_attn.k_img",
+            "attn2.add_v_proj": "cross_attn.v_img",
+            "attn2.norm_added_k": "cross_attn.norm_k_img",
+            "norm2": "norm3",
+            "ffn.net.0.proj": "ffn.0",
+            "ffn.net.2": "ffn.2",
+            "scale_shift_table": "modulation"
+        }
+    }
+}
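
For reference, the keymap above maps Diffusers-style WanTransformer parameter names onto diffsynth-engine's WanDiT names. Below is a minimal sketch of how such a keymap could be applied to a checkpoint; the `convert_wan_dit_key` helper and the `blocks.<i>.` prefix handling are illustrative assumptions, not the loader the engine actually uses.

```python
import json

def convert_wan_dit_key(key: str, global_map: dict, block_map: dict) -> str:
    # Per-block weights are assumed to look like "blocks.<i>.<module>.<param>".
    parts = key.split(".")
    if parts[0] == "blocks" and len(parts) > 2:
        block_prefix = ".".join(parts[:2])      # e.g. "blocks.3"
        rest = ".".join(parts[2:])              # e.g. "attn1.to_q.weight"
        for old, new in block_map.items():
            if rest == old or rest.startswith(old + "."):
                return f"{block_prefix}.{new}{rest[len(old):]}"
        return key
    # Top-level weights are matched against the global rename dict.
    for old, new in global_map.items():
        if key == old or key.startswith(old + "."):
            return new + key[len(old):]
    return key

with open("diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json") as f:
    keymap = json.load(f)["diffusers"]

print(convert_wan_dit_key("blocks.3.attn1.to_q.weight",
                          keymap["global_rename_dict"], keymap["rename_dict"]))
# -> "blocks.3.self_attn.q.weight" under the assumptions above
```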

diffsynth_engine/configs/pipeline.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass, field
 from typing import List, Dict, Tuple, Optional
 
 from diffsynth_engine.configs.controlnet import ControlType
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 
 
 @dataclass
@@ -30,16 +31,43 @@ class AttnImpl(Enum):
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
     SPARGE = "sparge"  # Sparge Attention
+    VSA = "vsa"  # Video Sparse Attention
+
+
+@dataclass
+class SpargeAttentionParams:
+    smooth_k: bool = True
+    cdfthreshd: float = 0.6
+    simthreshd1: float = 0.98
+    pvthreshd: float = 50.0
+
+
+@dataclass
+class VideoSparseAttentionParams:
+    sparsity: float = 0.9
 
 
 @dataclass
 class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
-
-
-
-
-
+    attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
+
+    def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
+        attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
+        if isinstance(self.attn_params, SpargeAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.attn_params.smooth_k,
+                    "simthreshd1": self.attn_params.simthreshd1,
+                    "cdfthreshd": self.attn_params.cdfthreshd,
+                    "pvthreshd": self.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.attn_params, VideoSparseAttentionParams):
+            assert self.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
+        return attn_kwargs
 
 
 @dataclass
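
The new `attn_params` field and `get_attn_kwargs` method give a single place to turn the configured attention backend into per-call keyword arguments. A minimal usage sketch, assuming the optional `vsa` package is installed (video_sparse_attention imports it at module scope, so this import chain requires it) and an illustrative Wan-style latent of shape [batch, channels, frames, height, width]:

```python
import torch
from diffsynth_engine.configs.pipeline import (
    AttentionConfig,
    AttnImpl,
    VideoSparseAttentionParams,
)

# Illustrative latent; the concrete sizes are assumptions for the example.
latents = torch.randn(1, 16, 21, 60, 104)

config = AttentionConfig(
    dit_attn_impl=AttnImpl.VSA,
    attn_params=VideoSparseAttentionParams(sparsity=0.9),
)

# get_attn_kwargs derives the VSA tiling metadata from latents.shape[2:]
# using the (1, 2, 2) patch size hard-coded in the method above.
attn_kwargs = config.get_attn_kwargs(latents, device="cuda")
print(attn_kwargs["attn_impl"], attn_kwargs["num_tiles"])
```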

diffsynth_engine/models/basic/attention.py
@@ -12,6 +12,7 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -20,12 +21,6 @@ FA3_MAX_HEADDIM = 256
 logger = logging.get_logger(__name__)
 
 
-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -33,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention
 
+def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+    padding_size = (alignment - x.shape[dim] % alignment) % alignment
+    padded_x = F.pad(x, (0, padding_size), "constant", 0)
+    return padded_x[..., : x.shape[dim]]
+
 def xformers_attn(q, k, v, attn_mask=None, scale=None):
     if attn_mask is not None:
         if attn_mask.ndim == 2:
@@ -94,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
         return out.transpose(1, 2)
 
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
@@ -109,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):
 
 
 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -133,6 +141,7 @@ def attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -189,10 +198,24 @@
             v,
             attn_mask=attn_mask,
             scale=scale,
-            smooth_k=kwargs.get("
-            simthreshd1=kwargs.get("
-            cdfthreshd=kwargs.get("
-            pvthreshd=kwargs.get("
+            smooth_k=kwargs.get("smooth_k", True),
+            simthreshd1=kwargs.get("simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("pvthreshd", 50),
+        )
+    if attn_impl == "vsa":
+        return video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
         )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
@@ -242,9 +265,10 @@ class Attention(nn.Module):
 
 
 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -267,6 +291,7 @@ def long_context_attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
@@ -307,11 +332,25 @@
     if attn_impl == "sparge":
         attn_processor = SparseAttentionMeansim()
         # default args from spas_sage2_attn_meansim_cuda
-        attn_processor.smooth_k = torch.tensor(kwargs.get("
-        attn_processor.simthreshd1 = torch.tensor(kwargs.get("
-        attn_processor.cdfthreshd = torch.tensor(kwargs.get("
-        attn_processor.pvthreshd = torch.tensor(kwargs.get("
+        attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
         return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
             q, k, v, softmax_scale=scale
         )
+    if attn_impl == "vsa":
+        return distributed_video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
+        )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
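
With the new `g` gate tensor and the `"vsa"` branch, the VSA path can be driven entirely through the kwargs produced by `get_vsa_kwargs`. A rough sketch of that wiring, assuming the `vsa` kernel is installed, a CUDA device, that `attention()` forwards extra keyword arguments (it reads them via `kwargs.get` in the vsa branch), and illustrative shapes where q/k/v/g are laid out as [batch, seq_len, heads, head_dim] with seq_len equal to the patched latent grid size:

```python
import torch
from diffsynth_engine.models.basic.attention import attention
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs

# Illustrative latent grid (frames, height, width) and the Wan patch size.
latent_shape, patch_size = (16, 60, 104), (1, 2, 2)
vsa_kwargs = get_vsa_kwargs(latent_shape, patch_size, sparsity=0.9, device="cuda")

seq_len = vsa_kwargs["total_seq_length"]  # 16 * 30 * 52 tokens after patching
q = torch.randn(1, seq_len, 12, 128, device="cuda", dtype=torch.bfloat16)
k, v, g = torch.randn_like(q), torch.randn_like(q), torch.randn_like(q)

out = attention(q, k, v, g, attn_impl="vsa", **vsa_kwargs)
print(out.shape)  # expected to match q's shape
```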

diffsynth_engine/models/basic/video_sparse_attention.py
@@ -0,0 +1,235 @@
+import torch
+import math
+import functools
+
+from vsa import video_sparse_attn as vsa_core
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast‑multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
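
As a quick sanity check on the arithmetic in this new module: the fixed (4, 4, 4) tile holds 64 tokens, `get_vsa_kwargs` pads the patched token grid up to whole tiles, and `video_sparse_attn` keeps roughly a `(1 - sparsity)` fraction of key tiles via `topk`. A worked example with an illustrative 16 x 30 x 52 patched grid:

```python
import math

# Mirrors the tile/top-k arithmetic in video_sparse_attention.py;
# the grid size and sparsity below are illustrative.
VSA_TILE_SIZE = (4, 4, 4)                    # 64 tokens per tile
dit_seq_shape = (16, 30, 52)
sparsity = 0.9

num_tiles = tuple(math.ceil(d / t) for d, t in zip(dit_seq_shape, VSA_TILE_SIZE))
total_seq_length = math.prod(dit_seq_shape)  # 24960 tokens
total_tiles = total_seq_length / math.prod(VSA_TILE_SIZE)

# topk tiles are retained out of ~total_tiles, i.e. roughly (1 - sparsity) of them.
topk = math.ceil((1 - sparsity) * total_tiles)

print(num_tiles)         # (4, 8, 13)
print(total_seq_length)  # 24960
print(topk)              # ceil(0.1 * 390) = 39
```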

diffsynth_engine/models/flux/flux_controlnet.py
@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -103,10 +102,7 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -128,6 +124,7 @@ class FluxControlNet(PreTrainedModel):
         image_ids: torch.Tensor,
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
@@ -141,7 +138,9 @@ class FluxControlNet(PreTrainedModel):
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))
 
         # apply control scale
@@ -149,24 +148,13 @@ class FluxControlNet(PreTrainedModel):
         return double_block_outputs, None
 
     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
            condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
            condition_channels = 64
 
-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)