sglang 0.4.4.post4__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/layers/attention/flashattention_backend.py +286 -9
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -0
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/model_executor/model_runner.py +1 -0
- sglang/srt/models/llama.py +12 -4
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/mllama4.py +154 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/METADATA +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/RECORD +32 -22
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/lang/chat_template.py
CHANGED
@@ -294,6 +294,30 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="llama-4",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|header_start|>system<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "user": (
+                "<|header_start|>user<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "assistant": (
+                "<|header_start|>assistant<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+        },
+        stop_str=("<|eot|>",),
+        image_token="<|image|>",
+    )
+)
+
 # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
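Not part of the diff: a minimal sketch of rendering a prompt by hand from the template registered above. It relies only on `get_chat_template` (an existing helper in `sglang.lang.chat_template`) and the `role_prefix_and_suffix` mapping shown in the hunk; the two conversation turns are invented.

```python
# Illustrative only: assemble a Llama-4-style prompt from the registered template.
# Assumes sglang >= 0.4.5 is installed; the conversation turns are made up.
from sglang.lang.chat_template import get_chat_template

template = get_chat_template("llama-4")
turns = [
    ("system", "You are a helpful assistant."),
    ("user", "Hello!"),
]

prompt = ""
for role, content in turns:
    prefix, suffix = template.role_prefix_and_suffix[role]
    prompt += prefix + content + suffix
# Open an assistant header so the model generates the reply; <|begin_of_text|>
# is expected to come from the tokenizer, not from the template.
prompt += template.role_prefix_and_suffix["assistant"][0]
print(prompt)
```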
sglang/srt/configs/model_config.py
CHANGED
@@ -65,6 +65,9 @@ class ModelConfig:
             **kwargs,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.attention_chunk_size = getattr(
+            self.hf_text_config, "attention_chunk_size", None
+        )
 
         # Check model type
         self.is_generation = is_generation_model(
@@ -467,6 +470,7 @@ multimodal_model_archs = [
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
+    # TODO: add multimodal support for "Llama4ForConditionalGeneration",
     "LlavaLlamaForCausalLM",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
sglang/srt/conversation.py
CHANGED
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
     LLAMA3 = auto()
+    LLAMA4 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -156,19 +157,30 @@ class Conversation:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA4:
+            # begin_of_text is added by default
+            if self.system_message:
+                ret = system_prompt
+            else:
+                ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+                    ret += f"{message.strip()}<|eot|>"
+                else:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA3:
-            ret = "<|begin_of_text|>"
             if self.system_message:
-                ret += system_prompt
+                ret = system_prompt
             else:
-                ret += ""
+                ret = ""
             for i, (role, message) in enumerate(self.messages):
                 if message:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                     ret += f"{message.strip()}<|eot_id|>"
                 else:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
-            # print(ret)
             return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
@@ -561,6 +573,19 @@ register_conv_template(
     )
 )
 
+# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="llama-4",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",
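Not part of the diff: the new `SeparatorStyle.LLAMA4` branch can be checked in isolation by constructing a `Conversation` that mirrors the template registered above and calling `get_prompt()` (the method modified in this hunk). A hedged sketch with invented messages:

```python
# Sketch of the LLAMA4 prompt format; assumes sglang >= 0.4.5.
from sglang.srt.conversation import Conversation, SeparatorStyle

conv = Conversation(
    name="llama-4-demo",
    system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
    system_message="You are a helpful assistant.",
    roles=("user", "assistant"),
    # (role, message) pairs; the trailing None leaves an open assistant header.
    messages=[["user", "Hello!"], ["assistant", None]],
    sep_style=SeparatorStyle.LLAMA4,
    sep="",
    stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
)
print(conv.get_prompt())
# -> "<|header_start|>system<|header_end|>\n\nYou are a helpful assistant.<|eot|>"
#    "<|header_start|>user<|header_end|>\n\nHello!<|eot|>"
#    "<|header_start|>assistant<|header_end|>\n\n"
```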
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import numpy as np
+
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 """
@@ -45,6 +47,206 @@ class FlashAttentionMetadata:
     # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None
 
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
+        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention
+        local_block_table: torch.Tensor = None  # block table for local attention
+        local_max_query_len: int = 0  # max query length for local attention
+        local_max_seq_len: int = 0  # max sequence length for local attention
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+# Copied from:
+# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if are performing a chunked prefill a batch of 3 sequences:
+#   q_seqlens  = [4, 10, 5]
+#   kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 | 1 1 1 1 1
+#               3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+# attention mask like:
+#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 |         1
+#               3 |         1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+# sequences into local ("virtual") batches, where each local batch item is a
+# local attention block, so in this case batch idx 0 would be broken up into:
+#
+#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
+#        k_toks >   0 1 2 3
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
+#        k_toks >   4 5
+#        q_toks v  _____________
+#               2 | 1
+#               3 | 1 1
+#
+# e.g. if we have:
+#   attn_chunk_size = 4
+#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+#                           __b0__  ______b1______  __b2__ < orig batch indices
+#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
+#   cu_seqlens_q_local = [0, 4,  6, 10, 14, 18, 19, 23, 24]
+#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
+#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+    local attention blocks, where each block is passed to the attention kernel
+    as an independent local ("virtual") batch item.
+
+    Args:
+        attn_chunk_size: Size of local attention chunks
+        query_start_loc_np: Cumulative sum of query lengths (numpy array)
+        seq_lens_np: Sequence lengths (numpy array)
+        block_table: Block table for KV cache
+        page_size: Size of each page in the KV cache
+
+    Returns:
+        seqlens_q_local: Query sequence lengths for local attention
+        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
+        seqlens_k_local: Key sequence lengths for local attention
+        block_table_local: Block table for local attention
+    """
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   new_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx, we can figure out the number of "virtual" requests we
+    # have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    # (TODO: max a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    # basically a full local attention block for all but the last block in each
+    # batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+    # For the example the local attention blocks start at:
+    #                           _b0_  _____b1_____  _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks
+    # For out example if we have a block-table like (assuming page_size=2):
+    #   block_table = [
+    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    #   ]
+    # Then for the local batches we would want a block-table like
+    #   block_table_local = [
+    #     [ 0,  1 ], < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2,  3 ], < local-batch 1, (batch 0, starting from k[4])
+    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    block_indices = block_indices.flatten()
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
 
 class FlashAttentionBackend(AttentionBackend):
     """FlashAttention backend implementation.
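The helper above is easiest to follow by running it on the example from its own comments. Not part of the diff: a sketch that does exactly that, using attn_chunk_size=4, q_seqlens=[4, 10, 5], kv_seqlens=[6, 17, 9], and page_size=2. It assumes sglang >= 0.4.5 (importing the backend module also pulls in its FlashAttention kernel dependency); the function body itself only needs numpy and torch.

```python
# Sketch: exercising make_local_attention_virtual_batches on the example from the
# comments above. Requires a full sglang >= 0.4.5 environment to import the backend.
import numpy as np
import torch

from sglang.srt.layers.attention.flashattention_backend import (
    make_local_attention_virtual_batches,
)

query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)  # cumsum of q_seqlens [4, 10, 5]
seq_lens = np.array([6, 17, 9], dtype=np.int32)             # kv length per request
block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)  # page_size=2 -> 10 pages/request

seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local = (
    make_local_attention_virtual_batches(
        attn_chunk_size=4,
        query_start_loc_np=query_start_loc,
        seq_lens_np=seq_lens,
        block_table=block_table,
        page_size=2,
    )
)

print(seqlens_q_local)    # [2 2 1 4 4 1 4 1]  -- one entry per virtual batch
print(seqlens_k_local)    # [4 2 4 4 4 1 4 1]
print(block_table_local)  # 8 x 2 page table, rows [0 1], [2 3], [12 13], [14 15], ...
```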
@@ -100,6 +302,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.step_id = step_id
         self.speculative_num_steps = speculative_num_steps
 
+        # Local attention settings
+        self.attention_chunk_size = (
+            model_runner.attention_chunk_size
+            if hasattr(model_runner, "attention_chunk_size")
+            else None
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
         metadata = FlashAttentionMetadata()
@@ -189,6 +398,7 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
         # Precompute cumulative sequence lengths
         if (
             any(forward_batch.extend_prefix_lens_cpu)
@@ -203,6 +413,51 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_q = metadata.cu_seqlens_k
             metadata.max_seq_len_q = metadata.max_seq_len_k
 
+            # Setup local attention if enabled
+            if (
+                self.attention_chunk_size is not None
+                and forward_batch.forward_mode == ForwardMode.EXTEND
+            ):
+                # Convert tensors to numpy for local attention processing
+                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
+
+                # Adjust attention_chunk_size based on the actual sequence length
+                # to avoid index out of bounds errors
+                max_seq_len = seq_lens_np.max()
+                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
+                # Make sure effective_chunk_size is divisible by page_size
+                effective_chunk_size = (
+                    effective_chunk_size // self.page_size
+                ) * self.page_size
+                if effective_chunk_size < self.page_size:
+                    effective_chunk_size = self.page_size
+
+                # Create local attention metadata
+                (
+                    seqlens_q_local_np,
+                    cu_seqlens_q_local_np,
+                    seqlens_k_local_np,
+                    block_table_local,
+                ) = make_local_attention_virtual_batches(
+                    effective_chunk_size,
+                    cu_seqlens_q_np,
+                    seq_lens_np,
+                    metadata.page_table,
+                    self.page_size,
+                )
+
+                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
+                        device
+                    ),
+                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+                    local_block_table=block_table_local,
+                    local_max_query_len=seqlens_q_local_np.max(),
+                    local_max_seq_len=seqlens_k_local_np.max(),
+                )
+                metadata.local_attn_metadata = local_metadata
+
         # Precompute strided indices
         if self.page_size > 1:
             self.strided_indices = torch.arange(
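A quick worked check of the chunk-size clamping above, with invented numbers: the chunk is capped at the longest sequence in the batch, rounded down to a multiple of the page size, and floored at one page, so it always satisfies the divisibility assert inside make_local_attention_virtual_batches.

```python
# Hypothetical values, mirroring the clamping logic in init_forward_metadata above.
attention_chunk_size = 8192   # from the model config (value here is made up)
max_seq_len = 1000            # longest sequence in this batch
page_size = 16

effective = min(attention_chunk_size, max_seq_len)  # 1000
effective = (effective // page_size) * page_size    # 992, a multiple of page_size
effective = max(effective, page_size)               # never below one page
print(effective)  # 992
```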
@@ -211,6 +466,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table = (
                 metadata.page_table[:, self.strided_indices] // self.page_size
             )
+
         self.forward_metadata = metadata
 
     def forward_extend(
@@ -254,7 +510,28 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
 
-
+        # Check if we should use local attention
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and metadata.local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )
+
+        # Get the appropriate page table based on whether we're using local attention
+        if use_local_attn:
+            local_metadata = metadata.local_attn_metadata
+            page_table = local_metadata.local_block_table
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            cache_seqlens = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+        else:
+            page_table = metadata.page_table
+            cu_seqlens_q = metadata.cu_seqlens_q
+            cache_seqlens = metadata.cache_seqlens_int32
+            max_seqlen_q = metadata.max_seq_len_q
+            max_seqlen_k = metadata.max_seq_len_k
+        cu_seqlens_k = metadata.cu_seqlens_k
 
         # Use Flash Attention for prefill
         if not self.use_mla:
@@ -272,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=key_cache,
                 v_cache=value_cache,
                 page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.max_seq_len_q,
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
                 causal=True,
                 window_size=window_size,
@@ -307,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
                 v_cache=c_kv_cache,
                 qv=q_nope,
                 page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.max_seq_len_q,
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -23,9 +23,14 @@ def fused_moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
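The new JSON files are tuning tables for the Triton fused-MoE kernel: each top-level key is a token-batch size M, and each value is the Triton launch configuration (block sizes, warps, stages) to use at that size. At run time the fused_moe_triton code, which is derived from vLLM, picks the entry whose key is closest to the actual M; the standalone sketch below reproduces that lookup. The `pick_moe_config` helper, the path, and the M value are illustrative only, and the nearest-key rule describes the upstream behaviour rather than anything shown verbatim in this diff.

```python
# Standalone sketch of the nearest-M lookup applied to the tuning tables above.
# The helper name, file path, and M value are examples, not sglang APIs.
import json


def pick_moe_config(config_path: str, m: int) -> dict:
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Use the tuned entry whose batch-size key is closest to the actual M.
    best_key = min(configs.keys(), key=lambda k: abs(k - m))
    return configs[best_key]


cfg = pick_moe_config(
    "E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json", m=200
)
print(cfg)  # the "256" entry: BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, ...
```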