sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -51,6 +51,7 @@ def _fwd_kernel(
     kv_indices,
     mask_ptr,
     mask_indptr,
+    sink_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -78,6 +79,7 @@ def _fwd_kernel(
     IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -132,38 +134,6 @@ def _fwd_kernel(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix

-        offs_kv_loc = tl.load(
-            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
-        )
-
-        # load k in transposed way
-        offs_buf_k = (
-            offs_kv_loc[None, :] * stride_buf_kbs
-            + cur_kv_head * stride_buf_kh
-            + offs_d[:, None]
-        )
-        k = tl.load(
-            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
-        )
-
-        qk = tl.dot(q.to(k.dtype), k)
-        if BLOCK_DPE > 0:
-            offs_kpe = (
-                offs_kv_loc[None, :] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[:, None]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_kpe,
-                mask=mask_n[None, :],
-                other=0.0,
-            )
-            qk += tl.dot(qpe.to(kpe.dtype), kpe)
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
         final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
             custom_mask = tl.load(
@@ -178,29 +148,77 @@ def _fwd_kernel(
             final_mask &= custom_mask
         if SLIDING_WINDOW_SIZE > 0:
             # Add mask where q_id <= kv_id + sliding_window_size
-            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
-                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
-            )
+            # q_id = prefix_len + cur_m, kv_id = cur_n
+            window_mask = (
+                cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+            ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
             final_mask &= window_mask
-        qk = tl.where(final_mask, qk, float("-inf"))

-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        re_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        deno = deno * re_scale + tl.sum(p, 1)
+        SKIP_TILE = False
+        if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
+            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

-        offs_buf_v = (
-            offs_kv_loc[:, None] * stride_buf_vbs
-            + cur_kv_head * stride_buf_vh
-            + offs_dv[None, :]
-        )
-        v = tl.load(
-            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
-        )
-        p = p.to(v.dtype)
-        acc = acc * re_scale[:, None] + tl.dot(p, v)
+        if not SKIP_TILE:
+            offs_kv_loc = tl.load(
+                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
+                mask=mask_n,
+                other=0,
+            )

-        e_max = n_e_max
+            # load k in transposed way
+            offs_buf_k = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_d[:, None]
+            )
+            k = tl.load(
+                K_Buffer + offs_buf_k,
+                mask=(mask_n[None, :]) & (mask_d[:, None]),
+                other=0.0,
+            )
+
+            qk = tl.dot(q.to(k.dtype), k)
+            if BLOCK_DPE > 0:
+                offs_kpe = (
+                    offs_kv_loc[None, :] * stride_buf_kbs
+                    + cur_kv_head * stride_buf_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Buffer + offs_kpe,
+                    mask=mask_n[None, :],
+                    other=0.0,
+                )
+                qk += tl.dot(qpe.to(kpe.dtype), kpe)
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(final_mask, qk, float("-inf"))
+
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
+
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            deno = deno * re_scale + tl.sum(p, 1)
+
+            offs_buf_v = (
+                offs_kv_loc[:, None] * stride_buf_vbs
+                + cur_kv_head * stride_buf_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Buffer + offs_buf_v,
+                mask=mask_n[:, None] & mask_dv[None, :],
+                other=0.0,
+            )
+            p = p.to(v.dtype)
+            acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+            e_max = n_e_max

     # stage 2: compute the triangle part

@@ -213,35 +231,7 @@ def _fwd_kernel(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end

-        # load k in transposed way
-        offs_k = (
-            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
-            + cur_kv_head * stride_kh
-            + offs_d[:, None]
-        )
-        k = tl.load(
-            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
-        )
-
-        qk = tl.dot(q, k, out_dtype=tl.float32)
-        if BLOCK_DPE > 0:
-            offs_kpe = (
-                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
-                + cur_kv_head * stride_kh
-                + offs_dpe[:, None]
-            )
-            kpe = tl.load(
-                K_Extend + offs_kpe,
-                mask=mask_n[None, :],
-                other=0.0,
-            )
-            qk += tl.dot(qpe, kpe)
-
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
+        final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr
@@ -254,34 +244,84 @@ def _fwd_kernel(
                 other=0,
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
+            final_mask &= custom_mask
         elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_causual, qk, float("-inf"))
+            final_mask &= mask_causual
         else:
             mask_non_causal = mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_non_causal, qk, float("-inf"))
+            final_mask &= mask_non_causal
+
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask

-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        re_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        deno = deno * re_scale + tl.sum(p, 1)
+        SKIP_TILE = False
+        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
+            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

-        offs_v = (
-            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
-            + cur_kv_head * stride_vh
-            + offs_dv[None, :]
-        )
-        v = tl.load(
-            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
-        )
-        p = p.to(v.dtype)
-        acc = acc * re_scale[:, None] + tl.dot(p, v)
+        if not SKIP_TILE:
+            # load k in transposed way
+            offs_k = (
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_d[:, None]
+            )
+            k = tl.load(
+                K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+            )

-        e_max = n_e_max
+            qk = tl.dot(q, k, out_dtype=tl.float32)
+            if BLOCK_DPE > 0:
+                offs_kpe = (
+                    (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                    + cur_kv_head * stride_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Extend + offs_kpe,
+                    mask=mask_n[None, :],
+                    other=0.0,
+                )
+                qk += tl.dot(qpe, kpe)
+
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(final_mask, qk, float("-inf"))
+
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
+
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            deno = deno * re_scale + tl.sum(p, 1)
+
+            offs_v = (
+                (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+                + cur_kv_head * stride_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+            )
+            p = p.to(v.dtype)
+            acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+            e_max = n_e_max
+
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        deno += tl.exp(cur_sink - e_max)

     offs_o = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
@@ -321,6 +361,7 @@ def extend_attention_fwd(
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
+    sinks=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -386,6 +427,8 @@ def extend_attention_fwd(
     # Skip custom mask for prefix part
     SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+    HAS_SINK = sinks is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1

@@ -405,6 +448,7 @@ def extend_attention_fwd(
         kv_indices,
         custom_mask,
         mask_indptr,
+        sinks,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -431,6 +475,7 @@ def extend_attention_fwd(
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+        HAS_SINK=HAS_SINK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
sglang/srt/layers/attention/trtllm_mha_backend.py (new file)

@@ -0,0 +1,332 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    import flashinfer
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+# Constants
+DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+# Reuse this workspace buffer across all TRTLLM MHA wrappers
+global_workspace_buffer = None
+
+
+@dataclass
+class TRTLLMMHAMetadata:
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
+    # Maximum sequence length for query
+    max_seq_len_q: int = 1
+    # Maximum sequence length for key
+    max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
+    # Page table, the index of KV Cache Tables/Blocks
+    page_table: torch.Tensor = None
+
+
+class TRTLLMHAAttnBackend(FlashInferAttnBackend):
+    """TRTLLM MHA attention kernel from flashinfer."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+        config = model_runner.model_config
+
+        # MHA-specific dimensions
+        self.max_context_len = model_runner.model_config.context_len
+        self.hidden_size = config.hidden_size
+
+        # Runtime parameters
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.page_size = model_runner.page_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.device = model_runner.device
+
+        # Workspace allocation
+        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+        # Allocate buffers
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
+
+        # CUDA graph state
+        self.decode_cuda_graph_metadata = {}
+
+        # Forward metadata
+        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        """Initialize CUDA graph state for TRTLLM MHA."""
+        self.decode_cuda_graph_metadata = {
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        """Initialize metadata for CUDA graph capture."""
+        metadata = TRTLLMMHAMetadata()
+
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = self.max_context_len
+
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Replay CUDA graph with new inputs."""
+        seq_lens = seq_lens[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+        metadata = None
+
+        # Normal Decode
+        metadata = self.decode_cuda_graph_metadata[bs]
+        max_len = seq_lens_cpu.max().item()
+        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+        metadata.max_seq_len_k = self.max_context_len
+
+        metadata.cache_seqlens_int32.copy_(seq_lens)
+        page_indices = self.req_to_token[
+            req_pool_indices[:, None],
+            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
+        ]
+        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        self.forward_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self) -> int:
+        """Get the fill value for sequence lengths in CUDA graph."""
+        return 1
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize the metadata for a forward pass."""
+
+        metadata = TRTLLMMHAMetadata()
+        seqlens_in_batch = forward_batch.seq_lens
+        batch_size = forward_batch.batch_size
+        device = seqlens_in_batch.device
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Normal Decode
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+        else:
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+                metadata.cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+            else:
+                metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+
+        # Convert the page table to a strided format
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
+        self.forward_metadata = metadata
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Run forward for decode using TRTLLM MHA kernel."""
+        cache_loc = forward_batch.out_cache_loc
+        if save_kv_cache and k is not None:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+
+        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        # shape conversion:
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+        k_cache = k_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        v_cache = v_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        kv_cache = (k_cache, v_cache)
+
+        # TODO: add support for quantization
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+        bmm1_scale = q_scale * k_scale * layer.scaling
+        bmm2_scale = 1.0
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+
+        # Call TRT-LLM kernel
+        # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
+        o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+            query=q,
+            kv_cache=kv_cache,
+            workspace_buffer=self.workspace_buffer,
+            block_tables=self.forward_metadata.page_table,
+            seq_lens=self.forward_metadata.cache_seqlens_int32,
+            max_seq_len=self.forward_metadata.max_seq_len_k,
+            bmm1_scale=bmm1_scale,
+            bmm2_scale=bmm2_scale,
+            window_left=layer.sliding_window_size,
+            # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+        **kwargs,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+        if save_kv_cache and k is not None:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        k_cache = k_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        v_cache = v_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        kv_cache = (k_cache, v_cache)
+
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+        bmm1_scale = q_scale * k_scale * layer.scaling
+        bmm2_scale = 1.0
+
+        o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
+            query=q,
+            kv_cache=kv_cache,
+            workspace_buffer=self.workspace_buffer,
+            block_tables=self.forward_metadata.page_table,
+            seq_lens=self.forward_metadata.cache_seqlens_int32,
+            max_q_len=self.forward_metadata.max_seq_len_q,
+            max_kv_len=self.forward_metadata.max_seq_len_k,
+            bmm1_scale=bmm1_scale,
+            bmm2_scale=bmm2_scale,
+            batch_size=forward_batch.batch_size,
+            cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
+            cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
+            window_left=layer.sliding_window_size,
+            # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)