sglang 0.4.4.post4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. sglang/lang/chat_template.py +24 -0
  2. sglang/srt/configs/model_config.py +4 -0
  3. sglang/srt/conversation.py +29 -4
  4. sglang/srt/layers/attention/flashattention_backend.py +286 -9
  5. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  6. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -3
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  15. sglang/srt/layers/quantization/__init__.py +1 -0
  16. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  17. sglang/srt/layers/quantization/fp8.py +3 -1
  18. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  19. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  20. sglang/srt/layers/radix_attention.py +2 -0
  21. sglang/srt/layers/rotary_embedding.py +63 -0
  22. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  23. sglang/srt/model_executor/model_runner.py +1 -0
  24. sglang/srt/models/llama.py +12 -4
  25. sglang/srt/models/llama4.py +420 -0
  26. sglang/srt/models/mllama4.py +154 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/METADATA +1 -1
  29. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/RECORD +32 -22
  30. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  31. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  32. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0

sglang/lang/chat_template.py
@@ -294,6 +294,30 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="llama-4",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|header_start|>system<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "user": (
+                "<|header_start|>user<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "assistant": (
+                "<|header_start|>assistant<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+        },
+        stop_str=("<|eot|>",),
+        image_token="<|image|>",
+    )
+)
+
 # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
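
For orientation, the registered prefixes and suffixes above compose a prompt as sketched below. This is a standalone illustration, not the sglang ChatTemplate machinery, and it assumes (per the matching conversation template later in this diff) that <|begin_of_text|> is added by the tokenizer rather than by the template.

    # Standalone sketch: how the "llama-4" role markers wrap each message.
    ROLE_WRAP = {
        "system": ("<|header_start|>system<|header_end|>\n\n", "<|eot|>"),
        "user": ("<|header_start|>user<|header_end|>\n\n", "<|eot|>"),
        "assistant": ("<|header_start|>assistant<|header_end|>\n\n", "<|eot|>"),
    }

    def render(messages):
        parts = []
        for role, text in messages:
            prefix, suffix = ROLE_WRAP[role]
            parts.append(prefix + text + suffix)
        parts.append(ROLE_WRAP["assistant"][0])  # open assistant turn for generation
        return "".join(parts)

    print(repr(render([("system", "Be brief."), ("user", "Hi!")])))
    # '<|header_start|>system<|header_end|>\n\nBe brief.<|eot|><|header_start|>user<|header_end|>\n\n'
    # 'Hi!<|eot|><|header_start|>assistant<|header_end|>\n\n'   (one string, wrapped here for width)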

sglang/srt/configs/model_config.py
@@ -65,6 +65,9 @@ class ModelConfig:
             **kwargs,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.attention_chunk_size = getattr(
+            self.hf_text_config, "attention_chunk_size", None
+        )
 
         # Check model type
         self.is_generation = is_generation_model(
@@ -467,6 +470,7 @@ multimodal_model_archs = [
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
+    # TODO: add multimodal support for "Llama4ForConditionalGeneration",
     "LlavaLlamaForCausalLM",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",

sglang/srt/conversation.py
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
     LLAMA3 = auto()
+    LLAMA4 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -156,19 +157,30 @@ class Conversation:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA4:
+            # begin_of_text is added by default
+            if self.system_message:
+                ret = system_prompt
+            else:
+                ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+                    ret += f"{message.strip()}<|eot|>"
+                else:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA3:
-            ret = "<|begin_of_text|>"
             if self.system_message:
-                ret += system_prompt
+                ret = system_prompt
             else:
-                ret += ""
+                ret = ""
             for i, (role, message) in enumerate(self.messages):
                 if message:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                     ret += f"{message.strip()}<|eot_id|>"
                 else:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
-            # print(ret)
             return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
@@ -561,6 +573,19 @@ register_conv_template(
     )
 )
 
+# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="llama-4",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",

sglang/srt/layers/attention/flashattention_backend.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import numpy as np
+
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 """
@@ -45,6 +47,206 @@ class FlashAttentionMetadata:
     # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None
 
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
+        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention
+        local_block_table: torch.Tensor = None  # block table for local attention
+        local_max_query_len: int = 0  # max query length for local attention
+        local_max_seq_len: int = 0  # max sequence length for local attention
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+# Copied from:
+# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if we are performing a chunked prefill on a batch of 3 sequences:
+#   q_seqlens  = [4, 10, 5]
+#   kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 | 1 1 1 1 1
+#               3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+# attention mask like:
+#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 |         1
+#               3 |         1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+# sequences into local ("virtual") batches, where each local batch item is a
+# local attention block, so in this case batch idx 0 would be broken up into:
+#
+#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
+#        k_toks >   0 1 2 3
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2)  (batch 0)
+#        k_toks >   4 5
+#        q_toks v  _____________
+#               2 | 1
+#               3 | 1 1
+#
+# e.g. if we have:
+#   attn_chunk_size = 4
+#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+#                           __b0__  ______b1______  __b2__   < orig batch indices
+#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
+#   cu_seqlens_q_local = [0, 4,  6, 10, 14, 18, 19, 23, 24]
+#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
+#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+    local attention blocks, where each block is passed to the attention kernel
+    as an independent local ("virtual") batch item.
+
+    Args:
+        attn_chunk_size: Size of local attention chunks
+        query_start_loc_np: Cumulative sum of query lengths (numpy array)
+        seq_lens_np: Sequence lengths (numpy array)
+        block_table: Block table for KV cache
+        page_size: Size of each page in the KV cache
+
+    Returns:
+        seqlens_q_local: Query sequence lengths for local attention
+        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
+        seqlens_k_local: Key sequence lengths for local attention
+        block_table_local: Block table for local attention
+    """
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   new_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx, we can figure out the number of "virtual" requests we
+    # have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    #   (TODO: make a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    # basically a full local attention block for all but the last block in each
+    # batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+    # For the example the local attention blocks start at:
+    #                           _b0_  _____b1_____  _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks
+    # For our example if we have a block-table like (assuming page_size=2):
+    #   block_table = [
+    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
+    #   ]
+    # Then for the local batches we would want a block-table like
+    #   block_table_local = [
+    #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
+    #     [12, 13 ],  < local-batch 2, (batch 1, starting from k[4])
+    #     [14, 15 ],  < local-batch 3, (batch 1, starting from k[8])
+    #     [16, 17 ],  < local-batch 4, (batch 1, starting from k[12])
+    #     [18, 19 ],  < local-batch 5, (batch 1, starting from k[16])
+    #     [22, 23 ],  < local-batch 6, (batch 2, starting from k[4])
+    #     [24, 25 ],  < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    block_indices = block_indices.flatten()
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
 
 class FlashAttentionBackend(AttentionBackend):
     """FlashAttention backend implementation.
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.step_id = step_id
         self.speculative_num_steps = speculative_num_steps
 
+        # Local attention settings
+        self.attention_chunk_size = (
+            model_runner.attention_chunk_size
+            if hasattr(model_runner, "attention_chunk_size")
+            else None
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
         metadata = FlashAttentionMetadata()
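
To make the chunking transformation concrete, the snippet below pushes the worked example from the comments of make_local_attention_virtual_batches (added above) through the helper: attn_chunk_size=4, q_seqlens=[4, 10, 5], kv_seqlens=[6, 17, 9], page_size=2. It assumes that function and cdiv are in scope; the toy block table is invented for illustration.

    import numpy as np
    import torch

    query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)  # cumulative q_seqlens
    seq_lens = np.array([6, 17, 9], dtype=np.int32)             # kv_seqlens
    block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)  # 3 requests x 10 pages

    q_local, cu_q_local, k_local, table_local = make_local_attention_virtual_batches(
        4, query_start_loc, seq_lens, block_table, page_size=2
    )
    print(q_local.tolist())          # [2, 2, 1, 4, 4, 1, 4, 1]
    print(k_local.tolist())          # [4, 2, 4, 4, 4, 1, 4, 1]
    print(table_local.tolist()[:2])  # [[0, 1], [2, 3]] -> the two local batches of request 0
    print(tuple(table_local.shape))  # (8, 2): 8 virtual batches, 2 pages per chunk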
@@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
+
             # Precompute cumulative sequence lengths
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
@@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_q = metadata.cu_seqlens_k
             metadata.max_seq_len_q = metadata.max_seq_len_k
 
+        # Setup local attention if enabled
+        if (
+            self.attention_chunk_size is not None
+            and forward_batch.forward_mode == ForwardMode.EXTEND
+        ):
+            # Convert tensors to numpy for local attention processing
+            cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+            seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
+
+            # Adjust attention_chunk_size based on the actual sequence length
+            # to avoid index out of bounds errors
+            max_seq_len = seq_lens_np.max()
+            effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
+            # Make sure effective_chunk_size is divisible by page_size
+            effective_chunk_size = (
+                effective_chunk_size // self.page_size
+            ) * self.page_size
+            if effective_chunk_size < self.page_size:
+                effective_chunk_size = self.page_size
+
+            # Create local attention metadata
+            (
+                seqlens_q_local_np,
+                cu_seqlens_q_local_np,
+                seqlens_k_local_np,
+                block_table_local,
+            ) = make_local_attention_virtual_batches(
+                effective_chunk_size,
+                cu_seqlens_q_np,
+                seq_lens_np,
+                metadata.page_table,
+                self.page_size,
+            )
+
+            local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
+                    device
+                ),
+                local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+                local_block_table=block_table_local,
+                local_max_query_len=seqlens_q_local_np.max(),
+                local_max_seq_len=seqlens_k_local_np.max(),
+            )
+            metadata.local_attn_metadata = local_metadata
+
         # Precompute strided indices
         if self.page_size > 1:
             self.strided_indices = torch.arange(
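
The clamping logic above shrinks the configured chunk to the longest sequence in the batch and rounds it down to a whole number of KV-cache pages, never going below one page. A minimal sketch of that arithmetic (the numeric values are invented for illustration):

    def effective_chunk_size(attention_chunk_size: int, max_seq_len: int, page_size: int) -> int:
        size = min(attention_chunk_size, max_seq_len)  # never larger than the longest sequence
        size = (size // page_size) * page_size         # round down to a multiple of page_size
        return max(size, page_size)                    # but never smaller than one page

    print(effective_chunk_size(8192, 100, 16))  # 96: clamped to 100, then rounded down to a multiple of 16
    print(effective_chunk_size(8192, 7, 16))    # 16: the floor would be 0, so fall back to one page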
@@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table = (
                 metadata.page_table[:, self.strided_indices] // self.page_size
             )
+
         self.forward_metadata = metadata
 
     def forward_extend(
@@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
 
-        page_table = metadata.page_table
+        # Check if we should use local attention
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and metadata.local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )
+
+        # Get the appropriate page table based on whether we're using local attention
+        if use_local_attn:
+            local_metadata = metadata.local_attn_metadata
+            page_table = local_metadata.local_block_table
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            cache_seqlens = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+        else:
+            page_table = metadata.page_table
+            cu_seqlens_q = metadata.cu_seqlens_q
+            cache_seqlens = metadata.cache_seqlens_int32
+            max_seqlen_q = metadata.max_seq_len_q
+            max_seqlen_k = metadata.max_seq_len_k
+        cu_seqlens_k = metadata.cu_seqlens_k
 
         # Use Flash Attention for prefill
         if not self.use_mla:
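
Note that the chunked ("local") path is taken only when local metadata was built for this batch and the layer opts in via use_irope; otherwise the regular global parameters are used. A standalone sketch of just that predicate (not the backend's method):

    def should_use_local_attn(attention_chunk_size, local_attn_metadata, layer) -> bool:
        # Mirrors the gating in forward_extend above.
        return (
            attention_chunk_size is not None
            and local_attn_metadata is not None
            and getattr(layer, "use_irope", False)
        )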
@@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=key_cache,
                 v_cache=value_cache,
                 page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.max_seq_len_q,
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
                 causal=True,
                 window_size=window_size,
@@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
                 v_cache=c_kv_cache,
                 qv=q_nope,
                 page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.max_seq_len_q,
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,

sglang/srt/layers/moe/fused_moe_native.py
@@ -23,9 +23,14 @@ def fused_moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,

sglang/srt/layers/moe/fused_moe_triton/configs/ (one of the new E=*,N=*,device_name=*.json tuning tables listed above)
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
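
Each entry in a table like this maps a token-count bucket to Triton launch parameters for the fused MoE kernel. A hedged sketch of how such a table is typically consumed: pick the entry whose key is numerically closest to the number of tokens being processed. The helper below is illustrative only, not sglang's actual config loader.

    import json

    def pick_moe_config(config_json: str, num_tokens: int) -> dict:
        # Hypothetical lookup: choose the bucket nearest to the actual token count.
        table = json.loads(config_json)
        nearest = min(table, key=lambda k: abs(int(k) - num_tokens))
        return table[nearest]

    example = '{"64": {"BLOCK_SIZE_M": 16}, "128": {"BLOCK_SIZE_M": 32}, "256": {"BLOCK_SIZE_M": 16}}'
    print(pick_moe_config(example, 100))  # nearest bucket is "128" -> {'BLOCK_SIZE_M': 32}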