diffsynth-engine 0.6.1.dev14__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. diffsynth_engine/__init__.py +6 -2
  2. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  3. diffsynth_engine/configs/__init__.py +10 -6
  4. diffsynth_engine/configs/pipeline.py +17 -10
  5. diffsynth_engine/models/base.py +1 -1
  6. diffsynth_engine/models/basic/attention.py +59 -20
  7. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  8. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  9. diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  10. diffsynth_engine/models/flux/flux_dit.py +27 -38
  11. diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  12. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  13. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  14. diffsynth_engine/models/qwen_image/qwen_image_dit.py +28 -34
  15. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  16. diffsynth_engine/models/wan/wan_audio_encoder.py +0 -1
  17. diffsynth_engine/models/wan/wan_dit.py +64 -27
  18. diffsynth_engine/pipelines/base.py +36 -4
  19. diffsynth_engine/pipelines/flux_image.py +19 -17
  20. diffsynth_engine/pipelines/qwen_image.py +45 -36
  21. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  22. diffsynth_engine/pipelines/utils.py +52 -0
  23. diffsynth_engine/pipelines/wan_s2v.py +4 -9
  24. diffsynth_engine/pipelines/wan_video.py +43 -19
  25. diffsynth_engine/tokenizers/base.py +6 -0
  26. diffsynth_engine/tokenizers/qwen2.py +12 -4
  27. diffsynth_engine/utils/constants.py +13 -12
  28. diffsynth_engine/utils/flag.py +6 -0
  29. diffsynth_engine/utils/parallel.py +62 -29
  30. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  31. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +45 -43
  32. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  33. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  34. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  35. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  36. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  37. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  38. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  39. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  40. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  41. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  42. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  43. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  44. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  45. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/__init__.py

@@ -12,11 +12,13 @@ from .configs import (
     WanStateDicts,
     QwenImageStateDicts,
     AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+    LoraConfig,
     ControlNetParams,
     ControlType,
     QwenImageControlNetParams,
     QwenImageControlType,
-    LoraConfig,
 )
 from .pipelines import (
     SDImagePipeline,
@@ -59,6 +61,9 @@ __all__ = [
     "WanStateDicts",
     "QwenImageStateDicts",
     "AttnImpl",
+    "SpargeAttentionParams",
+    "VideoSparseAttentionParams",
+    "LoraConfig",
     "ControlNetParams",
     "ControlType",
     "QwenImageControlNetParams",
@@ -79,7 +84,6 @@ __all__ = [
     "FluxIPAdapterRefTool",
     "FluxReplaceByControlTool",
     "FluxReduxRefTool",
-    "LoraConfig",
     "fetch_model",
     "fetch_modelscope_model",
     "register_fetch_modelscope_model",
diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json

@@ -0,0 +1,41 @@
+{
+    "diffusers": {
+        "global_rename_dict": {
+            "patch_embedding": "patch_embedding",
+            "condition_embedder.text_embedder.linear_1": "text_embedding.0",
+            "condition_embedder.text_embedder.linear_2": "text_embedding.2",
+            "condition_embedder.time_embedder.linear_1": "time_embedding.0",
+            "condition_embedder.time_embedder.linear_2": "time_embedding.2",
+            "condition_embedder.time_proj": "time_projection.1",
+            "condition_embedder.image_embedder.norm1": "img_emb.proj.0",
+            "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
+            "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
+            "condition_embedder.image_embedder.norm2": "img_emb.proj.4",
+            "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
+            "proj_out": "head.head",
+            "scale_shift_table": "head.modulation"
+        },
+        "rename_dict": {
+            "attn1.to_q": "self_attn.q",
+            "attn1.to_k": "self_attn.k",
+            "attn1.to_v": "self_attn.v",
+            "attn1.to_out.0": "self_attn.o",
+            "attn1.norm_q": "self_attn.norm_q",
+            "attn1.norm_k": "self_attn.norm_k",
+            "to_gate_compress": "self_attn.gate_compress",
+            "attn2.to_q": "cross_attn.q",
+            "attn2.to_k": "cross_attn.k",
+            "attn2.to_v": "cross_attn.v",
+            "attn2.to_out.0": "cross_attn.o",
+            "attn2.norm_q": "cross_attn.norm_q",
+            "attn2.norm_k": "cross_attn.norm_k",
+            "attn2.add_k_proj": "cross_attn.k_img",
+            "attn2.add_v_proj": "cross_attn.v_img",
+            "attn2.norm_added_k": "cross_attn.norm_k_img",
+            "norm2": "norm3",
+            "ffn.net.0.proj": "ffn.0",
+            "ffn.net.2": "ffn.2",
+            "scale_shift_table": "modulation"
+        }
+    }
+}
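
The new `wan_dit_keymap.json` maps diffusers-style Wan DiT parameter names onto diffsynth-engine's internal names: `global_rename_dict` covers top-level modules, `rename_dict` covers per-block modules. A hypothetical sketch of how such a two-level keymap could be applied to a diffusers state dict key (the `convert_key` helper and the `blocks.<i>.` key layout are assumptions for illustration, not this package's actual converter):

import json

def convert_key(name: str, keymap: dict) -> str:
    # Hypothetical helper; diffsynth-engine's real key conversion may differ.
    for src, dst in keymap["global_rename_dict"].items():
        if name == src or name.startswith(src + "."):
            return dst + name[len(src):]
    if name.startswith("blocks."):
        _, idx, suffix = name.split(".", 2)  # "blocks", block index, module path
        for src, dst in keymap["rename_dict"].items():
            if suffix == src or suffix.startswith(src + "."):
                return f"blocks.{idx}.{dst}{suffix[len(src):]}"
    return name

with open("diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json") as f:
    keymap = json.load(f)["diffusers"]

print(convert_key("blocks.0.attn1.to_q.weight", keymap))            # -> blocks.0.self_attn.q.weight
print(convert_key("condition_embedder.time_proj.weight", keymap))   # -> time_projection.1.weight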
diffsynth_engine/configs/__init__.py

@@ -17,14 +17,16 @@ from .pipeline import (
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
-    LoraConfig,
     AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+    LoraConfig,
 )
 from .controlnet import (
     ControlType,
     ControlNetParams,
-    QwenImageControlNetParams,
     QwenImageControlType,
+    QwenImageControlNetParams,
 )

 __all__ = [
@@ -46,10 +48,12 @@ __all__ = [
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
-    "QwenImageControlType",
-    "QwenImageControlNetParams",
+    "AttnImpl",
+    "SpargeAttentionParams",
+    "VideoSparseAttentionParams",
+    "LoraConfig",
     "ControlType",
     "ControlNetParams",
-    "LoraConfig",
-    "AttnImpl",
+    "QwenImageControlType",
+    "QwenImageControlNetParams",
 ]
diffsynth_engine/configs/pipeline.py

@@ -30,16 +30,26 @@ class AttnImpl(Enum):
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
     SPARGE = "sparge"  # Sparge Attention
+    VSA = "vsa"  # Video Sparse Attention
+
+
+@dataclass
+class SpargeAttentionParams:
+    smooth_k: bool = True
+    cdfthreshd: float = 0.6
+    simthreshd1: float = 0.98
+    pvthreshd: float = 50.0
+
+
+@dataclass
+class VideoSparseAttentionParams:
+    sparsity: float = 0.9


 @dataclass
 class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
-    # Sparge Attention
-    sparge_smooth_k: bool = True
-    sparge_cdfthreshd: float = 0.6
-    sparge_simthreshd1: float = 0.98
-    sparge_pvthreshd: float = 50.0
+    attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None


 @dataclass
@@ -234,14 +244,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     encoder_dtype: torch.dtype = torch.bfloat16
     vae_dtype: torch.dtype = torch.float32

+    load_encoder: bool = True
+
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009

-    # override BaseConfig
-    vae_tiled: bool = True
-    vae_tile_size: Tuple[int, int] = (34, 34)
-    vae_tile_stride: Tuple[int, int] = (18, 16)
-
     @classmethod
     def basic_config(
         cls,
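
The first `configs/pipeline.py` hunk above adds a `VSA` member to `AttnImpl` and moves the backend-specific knobs out of flat `sparge_*` fields on `AttentionConfig` into a single `attn_params` object. An illustrative sketch of selecting a backend through these fields (field names come from the diff; the surrounding dicts only stand in for a pipeline config call):

from diffsynth_engine import AttnImpl, SpargeAttentionParams, VideoSparseAttentionParams

# Video Sparse Attention: keep roughly 10% of key/value tiles (sparsity=0.9).
vsa_attention = dict(dit_attn_impl=AttnImpl.VSA, attn_params=VideoSparseAttentionParams(sparsity=0.9))

# Sparge Attention: the former sparge_* fields become dataclass fields without the prefix.
sparge_attention = dict(
    dit_attn_impl=AttnImpl.SPARGE,
    attn_params=SpargeAttentionParams(smooth_k=True, cdfthreshd=0.6, simthreshd1=0.98, pvthreshd=50.0),
)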
diffsynth_engine/models/base.py

@@ -57,7 +57,7 @@ class PreTrainedModel(nn.Module):
     def get_tp_plan(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support TP")

-    def get_fsdp_modules(self):
+    def get_fsdp_module_cls(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")


diffsynth_engine/models/basic/attention.py

@@ -12,6 +12,7 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8

@@ -20,12 +21,6 @@ FA3_MAX_HEADDIM = 256
 logger = logging.get_logger(__name__)


-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -33,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention

+    def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+        padding_size = (alignment - x.shape[dim] % alignment) % alignment
+        padded_x = F.pad(x, (0, padding_size), "constant", 0)
+        return padded_x[..., : x.shape[dim]]
+
     def xformers_attn(q, k, v, attn_mask=None, scale=None):
         if attn_mask is not None:
             if attn_mask.ndim == 2:
@@ -94,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
         return out.transpose(1, 2)


+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
@@ -109,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):


 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -133,6 +141,7 @@
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -189,10 +198,24 @@
             v,
             attn_mask=attn_mask,
             scale=scale,
-            smooth_k=kwargs.get("sparge_smooth_k", True),
-            simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
-            cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
-            pvthreshd=kwargs.get("sparge_pvthreshd", 50),
+            smooth_k=kwargs.get("smooth_k", True),
+            simthreshd1=kwargs.get("simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("pvthreshd", 50),
+        )
+    if attn_impl == "vsa":
+        return video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
         )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")

@@ -242,9 +265,10 @@ class Attention(nn.Module):


 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -267,6 +291,7 @@ def long_context_attention(
         "sdpa",
         "sage",
         "sparge",
+        "vsa",
     ]
     assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
@@ -307,11 +332,25 @@
     if attn_impl == "sparge":
         attn_processor = SparseAttentionMeansim()
         # default args from spas_sage2_attn_meansim_cuda
-        attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
-        attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
-        attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
-        attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
+        attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
         return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
             q, k, v, softmax_scale=scale
         )
+    if attn_impl == "vsa":
+        return distributed_video_sparse_attn(
+            q,
+            k,
+            v,
+            g,
+            sparsity=kwargs.get("sparsity"),
+            num_tiles=kwargs.get("num_tiles"),
+            total_seq_length=kwargs.get("total_seq_length"),
+            tile_partition_indices=kwargs.get("tile_partition_indices"),
+            reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+            variable_block_sizes=kwargs.get("variable_block_sizes"),
+            non_pad_index=kwargs.get("non_pad_index"),
+        )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
diffsynth_engine/models/basic/transformer_helper.py

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math


@@ -91,8 +92,8 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """

-    def forward(self, input: "torch.Tensor") -> "torch.Tensor":
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


 class ApproximateGELU(nn.Module):
@@ -115,3 +116,36 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        approximate: str = "none",
+        bias: bool = True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias, device=device, dtype=dtype)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        return F.gelu(gate, approximate=self.approximate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        x = self.gelu(x)
+        return x
diffsynth_engine/models/basic/video_sparse_attention.py

@@ -0,0 +1,238 @@
+import torch
+import math
+import functools
+
+from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from vsa import video_sparse_attn as vsa_core
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast‑multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
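
The new module carries the bookkeeping for Video Sparse Attention: tokens are permuted into (4, 4, 4) tile-major order, padded per tile, run through `vsa_core` with a top-k tile budget, and permuted back. A small worked example of the budget math and the tile/untile round trip, runnable on CPU since it never calls `vsa_core` (the 16x64x64 latent and (1, 2, 2) patch size are assumptions for illustration):

import math
import torch
from diffsynth_engine.models.basic.video_sparse_attention import (
    VSA_TILE_SIZE, get_vsa_kwargs, tile, untile,
)

kwargs = get_vsa_kwargs(latent_shape=(16, 64, 64), patch_size=(1, 2, 2), sparsity=0.9, device=torch.device("cpu"))
num_kv_tiles = kwargs["total_seq_length"] / math.prod(VSA_TILE_SIZE)   # 16384 tokens / 64 = 256 tiles
topk = math.ceil((1 - kwargs["sparsity"]) * num_kv_tiles)              # ceil(25.6) = 26: the top-k tile budget passed to vsa_core

# tile() pads and reorders tokens into tile-major order; untile() inverts it exactly.
x = torch.randn(1, kwargs["total_seq_length"], 12, 128)
x_tiled = tile(x, kwargs["num_tiles"], kwargs["tile_partition_indices"], kwargs["non_pad_index"])
assert torch.equal(untile(x_tiled, kwargs["reverse_tile_partition_indices"], kwargs["non_pad_index"]), x)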
diffsynth_engine/models/flux/flux_controlnet.py

@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -103,10 +102,7 @@
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -128,6 +124,7 @@
         image_ids: torch.Tensor,
         text_ids: torch.Tensor,
         guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
@@ -141,7 +138,9 @@
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb)
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))

         # apply control scale
@@ -149,24 +148,13 @@
         return double_block_outputs, None

     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
             condition_channels = 64

-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)