diffsynth_engine-0.5.1.dev4-py3-none-any.whl → diffsynth_engine-0.6.1.dev25-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/basic/attention.py

@@ -12,19 +12,15 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 FA3_MAX_HEADDIM = 256
 
 logger = logging.get_logger(__name__)
 
 
-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -32,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention
 
+    def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+        padding_size = (alignment - x.shape[dim] % alignment) % alignment
+        padded_x = F.pad(x, (0, padding_size), "constant", 0)
+        return padded_x[..., : x.shape[dim]]
+
     def xformers_attn(q, k, v, attn_mask=None, scale=None):
         if attn_mask is not None:
             if attn_mask.ndim == 2:
@@ -93,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
         return out.transpose(1, 2)
 
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
@@ -108,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):
 
 
 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -125,12 +134,14 @@ def attention(
         None,
         "auto",
         "eager",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "xformers",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
@@ -139,9 +150,13 @@ def attention(
             return flash_attn3(q, k, v, softmax_scale=scale)
         else:
             if not flash_attn3_compatible:
-                logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
             else:
-                logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
+                logger.debug(
+                    "flash_attn_3 does not support attention mask, will use fallback attention implementation"
+                )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -152,33 +167,55 @@ def attention(
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "flash_attn_3":
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
             if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
             if attn_mask is not None:
                 raise RuntimeError("flash_attn_3 does not support attention mask")
-            return flash_attn3(q, k, v, softmax_scale=scale)
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                origin_dtype = q.dtype
+                q = q.to(dtype=DTYPE_FP8)
+                k = k.to(dtype=DTYPE_FP8)
+                v = v.to(dtype=DTYPE_FP8)
+                out = flash_attn3(q, k, v, softmax_scale=scale)
+                return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return flash_attn2(q, k, v, softmax_scale=scale)
         if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sage_attn":
+        if attn_impl == "sage":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sparge":
             return sparge_attn(
                 q,
                 k,
                 v,
                 attn_mask=attn_mask,
                 scale=scale,
-                smooth_k=kwargs.get("sparge_smooth_k", True),
-                simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
-                cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
-                pvthreshd=kwargs.get("sparge_pvthreshd", 50),
+                smooth_k=kwargs.get("smooth_k", True),
+                simthreshd1=kwargs.get("simthreshd1", 0.6),
+                cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+                pvthreshd=kwargs.get("pvthreshd", 50),
+            )
+        if attn_impl == "vsa":
+            return video_sparse_attn(
+                q,
+                k,
+                v,
+                g,
+                sparsity=kwargs.get("sparsity"),
+                num_tiles=kwargs.get("num_tiles"),
+                total_seq_length=kwargs.get("total_seq_length"),
+                tile_partition_indices=kwargs.get("tile_partition_indices"),
+                reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+                variable_block_sizes=kwargs.get("variable_block_sizes"),
+                non_pad_index=kwargs.get("non_pad_index"),
             )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
@@ -228,9 +265,10 @@ class Attention(nn.Module):
 
 
 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
@@ -247,12 +285,15 @@ def long_context_attention(
     assert attn_impl in [
         None,
         "auto",
-        "flash_attn_2",
-        "flash_attn_3",
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "sdpa",
-        "sage_attn",
-        "sparge_attn",
+        "sage",
+        "sparge",
+        "vsa",
     ]
+    assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
@@ -268,27 +309,48 @@ def long_context_attention(
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         raise ValueError("No available long context attention implementation")
     else:
-        if attn_impl == "flash_attn_3":
-            if flash_attn3_compatible:
-                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
-            else:
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
+            if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
-        if attn_impl == "flash_attn_2":
+            if attn_impl == "fa3":
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         if attn_impl == "sdpa":
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sage_attn":
-            return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
-        if attn_impl == "sparge_attn":
+        if attn_impl == "sage":
+            return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge":
             attn_processor = SparseAttentionMeansim()
             # default args from spas_sage2_attn_meansim_cuda
-            attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
-            attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
-            attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
-            attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
+            attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+            attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+            attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+            attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
             return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
                 q, k, v, softmax_scale=scale
             )
+        if attn_impl == "vsa":
+            return distributed_video_sparse_attn(
+                q,
+                k,
+                v,
+                g,
+                sparsity=kwargs.get("sparsity"),
+                num_tiles=kwargs.get("num_tiles"),
+                total_seq_length=kwargs.get("total_seq_length"),
+                tile_partition_indices=kwargs.get("tile_partition_indices"),
+                reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+                variable_block_sizes=kwargs.get("variable_block_sizes"),
+                non_pad_index=kwargs.get("non_pad_index"),
+            )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")

diffsynth_engine/models/basic/transformer_helper.py

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math
 
 
@@ -91,8 +92,8 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
 
-    def forward(self, input: "torch.Tensor") -> "torch.Tensor":
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
 
 class ApproximateGELU(nn.Module):
@@ -115,3 +116,36 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        approximate: str = "none",
+        bias: bool = True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias, device=device, dtype=dtype)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        return F.gelu(gate, approximate=self.approximate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        x = self.gelu(x)
+        return x
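
A quick usage sketch of the new GELU block added here (the CPU device and float32 dtype are chosen only to keep the example self-contained; the module defaults to "cuda:0" and float16):

    import torch
    from diffsynth_engine.models.basic.transformer_helper import GELU

    ff_in = GELU(dim_in=64, dim_out=256, approximate="tanh", device="cpu", dtype=torch.float32)
    x = torch.randn(2, 16, 64)
    y = ff_in(x)       # Linear(64 -> 256) followed by tanh-approximated GELU
    print(y.shape)     # torch.Size([2, 16, 256])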

diffsynth_engine/models/basic/video_sparse_attention.py

@@ -0,0 +1,238 @@
+import torch
+import math
+import functools
+
+from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from vsa import video_sparse_attn as vsa_core
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast‑multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
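
For orientation, a sketch of how the metadata precomputed by get_vsa_kwargs feeds the "vsa" branch of the attention dispatcher above. The latent shape, patch size, head count, and head dimension are illustrative assumptions, and running it requires a CUDA device plus the optional vsa kernel package:

    import torch
    from diffsynth_engine.models.basic.attention import attention
    from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs

    latent_shape = (16, 60, 104)       # latent frames, height, width (illustrative)
    patch_size = (1, 2, 2)             # DiT patchification (illustrative)
    vsa_kwargs = get_vsa_kwargs(latent_shape, patch_size, sparsity=0.9, device=torch.device("cuda"))

    seq_len = vsa_kwargs["total_seq_length"]
    q = k = v = torch.randn(1, seq_len, 40, 128, device="cuda", dtype=torch.bfloat16)
    g = torch.randn_like(q)            # forwarded to the kernel as compress_attn_weight

    out = attention(q, k, v, g, attn_impl="vsa", **vsa_kwargs)

distributed_video_sparse_attn follows the same pattern under Ulysses sequence parallelism (ring degree 1), reached through long_context_attention with attn_impl="vsa".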

diffsynth_engine/models/flux/flux_controlnet.py

@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -103,10 +102,7 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(
@@ -119,18 +115,17 @@ class FluxControlNet(PreTrainedModel):
 
     def forward(
         self,
-        hidden_states,
-        control_condition,
-        control_scale,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-        guidance,
-        image_ids,
-        text_ids,
+        hidden_states: torch.Tensor,
+        control_condition: torch.Tensor,
+        control_scale: float,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        hidden_states = self.patchify(hidden_states)
-        control_condition = self.patchify(control_condition)
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
             self.time_embedder(timestep, hidden_states.dtype)
@@ -143,7 +138,9 @@ class FluxControlNet(PreTrainedModel):
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb)
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))
 
         # apply control scale
@@ -151,24 +148,13 @@ class FluxControlNet(PreTrainedModel):
         return double_block_outputs, None
 
     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
             condition_channels = 64
 
-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)