sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -51,6 +51,7 @@ def _fwd_kernel(
  kv_indices,
  mask_ptr,
  mask_indptr,
+ sink_ptr,
  sm_scale,
  kv_group_num,
  stride_qbs,
@@ -78,6 +79,7 @@ def _fwd_kernel(
  IS_CAUSAL: tl.constexpr,
  SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
  STORE_TRANSPOSE: tl.constexpr,
+ HAS_SINK: tl.constexpr,
  ):
  cur_seq = tl.program_id(0)
  cur_head = tl.program_id(1)
@@ -132,38 +134,6 @@ def _fwd_kernel(
  start_n = tl.multiple_of(start_n, BLOCK_N)
  mask_n = (start_n + offs_n) < cur_seq_len_prefix

- offs_kv_loc = tl.load(
- kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
- )
-
- # load k in transposed way
- offs_buf_k = (
- offs_kv_loc[None, :] * stride_buf_kbs
- + cur_kv_head * stride_buf_kh
- + offs_d[:, None]
- )
- k = tl.load(
- K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
- )
-
- qk = tl.dot(q.to(k.dtype), k)
- if BLOCK_DPE > 0:
- offs_kpe = (
- offs_kv_loc[None, :] * stride_buf_kbs
- + cur_kv_head * stride_buf_kh
- + offs_dpe[:, None]
- )
- kpe = tl.load(
- K_Buffer + offs_kpe,
- mask=mask_n[None, :],
- other=0.0,
- )
- qk += tl.dot(qpe.to(kpe.dtype), kpe)
- qk *= sm_scale
-
- if logit_cap > 0:
- qk = logit_cap * tanh(qk / logit_cap)
-
  final_mask = mask_m[:, None] & mask_n[None, :]
  if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
  custom_mask = tl.load(
@@ -178,29 +148,77 @@ def _fwd_kernel(
  final_mask &= custom_mask
  if SLIDING_WINDOW_SIZE > 0:
  # Add mask where q_id <= kv_id + sliding_window_size
- window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
- start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
- )
+ # q_id = prefix_len + cur_m, kv_id = cur_n
+ window_mask = (
+ cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+ ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
  final_mask &= window_mask
- qk = tl.where(final_mask, qk, float("-inf"))

- n_e_max = tl.maximum(tl.max(qk, 1), e_max)
- re_scale = tl.exp(e_max - n_e_max)
- p = tl.exp(qk - n_e_max[:, None])
- deno = deno * re_scale + tl.sum(p, 1)
+ SKIP_TILE = False
+ if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
+ SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

- offs_buf_v = (
- offs_kv_loc[:, None] * stride_buf_vbs
- + cur_kv_head * stride_buf_vh
- + offs_dv[None, :]
- )
- v = tl.load(
- V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
- )
- p = p.to(v.dtype)
- acc = acc * re_scale[:, None] + tl.dot(p, v)
+ if not SKIP_TILE:
+ offs_kv_loc = tl.load(
+ kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
+ mask=mask_n,
+ other=0,
+ )

- e_max = n_e_max
+ # load k in transposed way
+ offs_buf_k = (
+ offs_kv_loc[None, :] * stride_buf_kbs
+ + cur_kv_head * stride_buf_kh
+ + offs_d[:, None]
+ )
+ k = tl.load(
+ K_Buffer + offs_buf_k,
+ mask=(mask_n[None, :]) & (mask_d[:, None]),
+ other=0.0,
+ )
+
+ qk = tl.dot(q.to(k.dtype), k)
+ if BLOCK_DPE > 0:
+ offs_kpe = (
+ offs_kv_loc[None, :] * stride_buf_kbs
+ + cur_kv_head * stride_buf_kh
+ + offs_dpe[:, None]
+ )
+ kpe = tl.load(
+ K_Buffer + offs_kpe,
+ mask=mask_n[None, :],
+ other=0.0,
+ )
+ qk += tl.dot(qpe.to(kpe.dtype), kpe)
+ qk *= sm_scale
+
+ if logit_cap > 0:
+ qk = logit_cap * tanh(qk / logit_cap)
+
+ qk = tl.where(final_mask, qk, float("-inf"))
+
+ row_max = tl.max(qk, 1)
+ row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+ n_e_max = tl.maximum(row_max_fixed, e_max)
+
+ re_scale = tl.exp(e_max - n_e_max)
+ p = tl.exp(qk - n_e_max[:, None])
+ deno = deno * re_scale + tl.sum(p, 1)
+
+ offs_buf_v = (
+ offs_kv_loc[:, None] * stride_buf_vbs
+ + cur_kv_head * stride_buf_vh
+ + offs_dv[None, :]
+ )
+ v = tl.load(
+ V_Buffer + offs_buf_v,
+ mask=mask_n[:, None] & mask_dv[None, :],
+ other=0.0,
+ )
+ p = p.to(v.dtype)
+ acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+ e_max = n_e_max

  # stage 2: compute the triangle part
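Before the stage-2 hunks below, note the two behavioral changes the stage-1 (prefix) hunks above introduce: tiles whose combined mask is completely empty, which happens with sliding windows and custom masks, are now skipped before any K/V load, and the running softmax maximum is guarded against fully masked rows. A minimal dense-tensor sketch of both ideas outside Triton; tile_can_be_skipped and safe_running_max are illustrative names, not functions in sglang:

import torch

def tile_can_be_skipped(final_mask: torch.Tensor) -> bool:
    # Mirrors SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0:
    # if no (m, n) position in the tile survives masking, its K/V loads and
    # dot products can be skipped without changing the online-softmax result.
    return int(final_mask.to(torch.int32).amax()) == 0

def safe_running_max(qk: torch.Tensor, e_max: torch.Tensor) -> torch.Tensor:
    # Mirrors the row_max / row_max_fixed lines: a fully masked row has
    # max(qk) == -inf, which together with a running max that is still -inf
    # would turn exp(e_max - n_e_max) into NaN, so the row max is clamped to
    # a large negative finite value before updating the running max.
    row_max = qk.max(dim=1).values
    row_max_fixed = torch.where(
        row_max == float("-inf"), torch.full_like(row_max, -1e20), row_max
    )
    return torch.maximum(row_max_fixed, e_max)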

@@ -213,35 +231,7 @@ def _fwd_kernel(
  start_n = tl.multiple_of(start_n, BLOCK_N)
  mask_n = (start_n + offs_n) < cur_block_m_end

- # load k in transposed way
- offs_k = (
- (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
- + cur_kv_head * stride_kh
- + offs_d[:, None]
- )
- k = tl.load(
- K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
- )
-
- qk = tl.dot(q, k, out_dtype=tl.float32)
- if BLOCK_DPE > 0:
- offs_kpe = (
- (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
- + cur_kv_head * stride_kh
- + offs_dpe[:, None]
- )
- kpe = tl.load(
- K_Extend + offs_kpe,
- mask=mask_n[None, :],
- other=0.0,
- )
- qk += tl.dot(qpe, kpe)
-
- qk *= sm_scale
-
- if logit_cap > 0:
- qk = logit_cap * tanh(qk / logit_cap)
-
+ final_mask = mask_m[:, None] & mask_n[None, :]
  if USE_CUSTOM_MASK:
  custom_mask = tl.load(
  mask_ptr
@@ -254,34 +244,84 @@ def _fwd_kernel(
  other=0,
  )
  custom_mask &= mask_m[:, None] & mask_n[None, :]
- qk = tl.where(custom_mask, qk, float("-inf"))
+ final_mask &= custom_mask
  elif IS_CAUSAL:
  mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
  start_n + offs_n[None, :]
  )
  mask_causual &= mask_m[:, None] & mask_n[None, :]
- qk = tl.where(mask_causual, qk, float("-inf"))
+ final_mask &= mask_causual
  else:
  mask_non_causal = mask_m[:, None] & mask_n[None, :]
- qk = tl.where(mask_non_causal, qk, float("-inf"))
+ final_mask &= mask_non_causal
+
+ if SLIDING_WINDOW_SIZE > 0:
+ # Add mask where q_id <= kv_id + sliding_window_size
+ window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+ start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+ )
+ final_mask &= window_mask

- n_e_max = tl.maximum(tl.max(qk, 1), e_max)
- re_scale = tl.exp(e_max - n_e_max)
- p = tl.exp(qk - n_e_max[:, None])
- deno = deno * re_scale + tl.sum(p, 1)
+ SKIP_TILE = False
+ if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
+ SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

- offs_v = (
- (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
- + cur_kv_head * stride_vh
- + offs_dv[None, :]
- )
- v = tl.load(
- V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
- )
- p = p.to(v.dtype)
- acc = acc * re_scale[:, None] + tl.dot(p, v)
+ if not SKIP_TILE:
+ # load k in transposed way
+ offs_k = (
+ (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ + cur_kv_head * stride_kh
+ + offs_d[:, None]
+ )
+ k = tl.load(
+ K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+ )

- e_max = n_e_max
+ qk = tl.dot(q, k, out_dtype=tl.float32)
+ if BLOCK_DPE > 0:
+ offs_kpe = (
+ (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ + cur_kv_head * stride_kh
+ + offs_dpe[:, None]
+ )
+ kpe = tl.load(
+ K_Extend + offs_kpe,
+ mask=mask_n[None, :],
+ other=0.0,
+ )
+ qk += tl.dot(qpe, kpe)
+
+ qk *= sm_scale
+
+ if logit_cap > 0:
+ qk = logit_cap * tanh(qk / logit_cap)
+
+ qk = tl.where(final_mask, qk, float("-inf"))
+
+ row_max = tl.max(qk, 1)
+ row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+ n_e_max = tl.maximum(row_max_fixed, e_max)
+
+ re_scale = tl.exp(e_max - n_e_max)
+ p = tl.exp(qk - n_e_max[:, None])
+ deno = deno * re_scale + tl.sum(p, 1)
+
+ offs_v = (
+ (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+ + cur_kv_head * stride_vh
+ + offs_dv[None, :]
+ )
+ v = tl.load(
+ V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+ )
+ p = p.to(v.dtype)
+ acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+ e_max = n_e_max
+
+ if HAS_SINK:
+ cur_sink = tl.load(sink_ptr + cur_head)
+ deno += tl.exp(cur_sink - e_max)

  offs_o = (
  (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
@@ -321,6 +361,7 @@ def extend_attention_fwd(
  logit_cap=0.0,
  skip_prefix_custom_mask=True,
  sliding_window_size=-1,
+ sinks=None,
  ):
  """
  q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -386,6 +427,8 @@ def extend_attention_fwd(
  # Skip custom mask for prefix part
  SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+ HAS_SINK = sinks is not None
+
  grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
  num_stages = 1

@@ -405,6 +448,7 @@ def extend_attention_fwd(
  kv_indices,
  custom_mask,
  mask_indptr,
+ sinks,
  sm_scale,
  kv_group_num,
  q_extend.stride(0),
@@ -431,6 +475,7 @@ def extend_attention_fwd(
  USE_CUSTOM_MASK=USE_CUSTOM_MASK,
  IS_CAUSAL=is_causal,
  SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+ HAS_SINK=HAS_SINK,
  STORE_TRANSPOSE=_is_hip,
  num_warps=num_warps,
  num_stages=num_stages,
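
For reference, the HAS_SINK path added above only changes the softmax normalizer: each head carries one extra logit (the sink) that contributes exp(sink - e_max) to the denominator but has no value row, matching deno += tl.exp(cur_sink - e_max). A minimal dense-tensor sketch of the same math; softmax_with_sink is an illustrative name, not part of sglang:

import torch

def softmax_with_sink(scores: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    # scores: [num_heads, q_len, kv_len]; sinks: [num_heads]
    e_max = scores.max(dim=-1, keepdim=True).values  # max over the real logits only
    p = torch.exp(scores - e_max)
    deno = p.sum(dim=-1, keepdim=True) + torch.exp(sinks.view(-1, 1, 1) - e_max)
    return p / deno  # rows no longer sum to 1; the sink absorbs the remainder

Because extend_attention_fwd defaults to sinks=None and the kernel reads sink_ptr only when HAS_SINK is set, existing callers that do not pass sinks are unaffected.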
sglang/srt/layers/attention/trtllm_mha_backend.py (new file)
@@ -0,0 +1,332 @@
+ from __future__ import annotations
+
+ """
+ Support attention backend for TRTLLM MHA kernels from flashinfer.
+ The kernel supports sm100 only, with sliding window and attention sink features.
+ """
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional
+
+ import torch
+
+ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.utils import is_flashinfer_available
+
+ if is_flashinfer_available():
+ import flashinfer
+
+ if TYPE_CHECKING:
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.model_runner import ModelRunner
+ from sglang.srt.speculative.spec_info import SpecInfo
+
+ # Constants
+ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
+
+ # Reuse this workspace buffer across all TRTLLM MHA wrappers
+ global_workspace_buffer = None
+
+
+ @dataclass
+ class TRTLLMMHAMetadata:
+ # Sequence lengths for the forward batch
+ cache_seqlens_int32: torch.Tensor = None
+ # Maximum sequence length for query
+ max_seq_len_q: int = 1
+ # Maximum sequence length for key
+ max_seq_len_k: int = 0
+ # Cumulative sequence lengths for query
+ cu_seqlens_q: torch.Tensor = None
+ # Cumulative sequence lengths for key
+ cu_seqlens_k: torch.Tensor = None
+ # Page table, the index of KV Cache Tables/Blocks
+ page_table: torch.Tensor = None
+
+
+ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
+ """TRTLLM MHA attention kernel from flashinfer."""
+
+ def __init__(
+ self,
+ model_runner: ModelRunner,
+ skip_prefill: bool = False,
+ kv_indptr_buf: Optional[torch.Tensor] = None,
+ q_indptr_decode_buf: Optional[torch.Tensor] = None,
+ ):
+ super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+ config = model_runner.model_config
+
+ # MHA-specific dimensions
+ self.max_context_len = model_runner.model_config.context_len
+ self.hidden_size = config.hidden_size
+
+ # Runtime parameters
+ self.data_type = model_runner.kv_cache_dtype
+ self.q_data_type = model_runner.dtype
+ self.page_size = model_runner.page_size
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.device = model_runner.device
+
+ # Workspace allocation
+ self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+ # Allocate buffers
+ global global_workspace_buffer
+ if global_workspace_buffer is None:
+ global_workspace_buffer = torch.empty(
+ self.workspace_size,
+ dtype=torch.uint8,
+ device=model_runner.device,
+ )
+ self.workspace_buffer = global_workspace_buffer
+
+ # CUDA graph state
+ self.decode_cuda_graph_metadata = {}
+
+ # Forward metadata
+ self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
+
+ def init_cuda_graph_state(
+ self,
+ max_bs: int,
+ max_num_tokens: int,
+ kv_indices_buf: Optional[torch.Tensor] = None,
+ ):
+ """Initialize CUDA graph state for TRTLLM MHA."""
+ self.decode_cuda_graph_metadata = {
+ "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+ "page_table": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "strided_indices": torch.arange(
+ 0, self.max_context_len, self.page_size, device=self.device
+ ),
+ }
+
+ def init_forward_metadata_capture_cuda_graph(
+ self,
+ bs: int,
+ num_tokens: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
+ ):
+ """Initialize metadata for CUDA graph capture."""
+ metadata = TRTLLMMHAMetadata()
+
+ # Get sequence information
+ metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+
+ # Precompute maximum sequence length
+ metadata.max_seq_len_k = self.max_context_len
+
+ # Precompute page table
+ metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+ self.decode_cuda_graph_metadata[bs] = metadata
+ self.forward_metadata = metadata
+
+ def init_forward_metadata_replay_cuda_graph(
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
+ seq_lens_cpu: Optional[torch.Tensor],
+ ):
+ """Replay CUDA graph with new inputs."""
+ seq_lens = seq_lens[:bs]
+ seq_lens_cpu = seq_lens_cpu[:bs]
+ req_pool_indices = req_pool_indices[:bs]
+ device = seq_lens.device
+ metadata = None
+
+ # Normal Decode
+ metadata = self.decode_cuda_graph_metadata[bs]
+ max_len = seq_lens_cpu.max().item()
+ max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+ metadata.max_seq_len_k = self.max_context_len
+
+ metadata.cache_seqlens_int32.copy_(seq_lens)
+ page_indices = self.req_to_token[
+ req_pool_indices[:, None],
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
+ ]
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+ self.forward_metadata = metadata
+
+ def get_cuda_graph_seq_len_fill_value(self) -> int:
+ """Get the fill value for sequence lengths in CUDA graph."""
+ return 1
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ """Initialize the metadata for a forward pass."""
+
+ metadata = TRTLLMMHAMetadata()
+ seqlens_in_batch = forward_batch.seq_lens
+ batch_size = forward_batch.batch_size
+ device = seqlens_in_batch.device
+
+ if forward_batch.forward_mode.is_decode_or_idle():
+ # Normal Decode
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ else:
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ if any(forward_batch.extend_prefix_lens_cpu):
+ extend_seq_lens = forward_batch.extend_seq_lens
+ metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+ metadata.cu_seqlens_q = torch.nn.functional.pad(
+ torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ )
+ else:
+ metadata.max_seq_len_q = metadata.max_seq_len_k
+ metadata.cu_seqlens_q = metadata.cu_seqlens_k
+
+ # Convert the page table to a strided format
+ if self.page_size > 1:
+ self.strided_indices = torch.arange(
+ 0, metadata.page_table.shape[1], self.page_size, device=self.device
+ )
+ metadata.page_table = (
+ metadata.page_table[:, self.strided_indices] // self.page_size
+ )
+
+ self.forward_metadata = metadata
+
+ def forward_decode(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache: bool = True,
+ **kwargs,
+ ) -> torch.Tensor:
+ """Run forward for decode using TRTLLM MHA kernel."""
+ cache_loc = forward_batch.out_cache_loc
+ if save_kv_cache and k is not None:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )
+
+ q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+ # shape conversion:
+ # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+ k_cache = k_cache.view(
+ -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+ ).permute(0, 2, 1, 3)
+ v_cache = v_cache.view(
+ -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+ ).permute(0, 2, 1, 3)
+ kv_cache = (k_cache, v_cache)
+
+ # TODO: add support for quantization
+ q_scale = 1.0
+ k_scale = (
+ layer.k_scale_float
+ if getattr(layer, "k_scale_float", None) is not None
+ else 1.0
+ )
+ bmm1_scale = q_scale * k_scale * layer.scaling
+ bmm2_scale = 1.0
+ # sink: additional value per head in the denominator of the softmax.
+ attention_sink = kwargs.get("sinks", None)
+
+ # Call TRT-LLM kernel
+ # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
+ o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+ query=q,
+ kv_cache=kv_cache,
+ workspace_buffer=self.workspace_buffer,
+ block_tables=self.forward_metadata.page_table,
+ seq_lens=self.forward_metadata.cache_seqlens_int32,
+ max_seq_len=self.forward_metadata.max_seq_len_k,
+ bmm1_scale=bmm1_scale,
+ bmm2_scale=bmm2_scale,
+ window_left=layer.sliding_window_size,
+ # TODO: add attention_sink operation or nvfp4 scale factor if needed
+ sinks=attention_sink,
+ )
+
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+ def forward_extend(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache=True,
+ **kwargs,
+ ):
+ cache_loc = forward_batch.out_cache_loc
+ if save_kv_cache and k is not None:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )
+ q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+ k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+ k_cache = k_cache.view(
+ -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+ ).permute(0, 2, 1, 3)
+ v_cache = v_cache.view(
+ -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+ ).permute(0, 2, 1, 3)
+ kv_cache = (k_cache, v_cache)
+
+ # sink: additional value per head in the denominator of the softmax.
+ attention_sink = kwargs.get("sinks", None)
+ # TODO: add support for quantization
+ q_scale = 1.0
+ k_scale = (
+ layer.k_scale_float
+ if getattr(layer, "k_scale_float", None) is not None
+ else 1.0
+ )
+ bmm1_scale = q_scale * k_scale * layer.scaling
+ bmm2_scale = 1.0
+
+ o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
+ query=q,
+ kv_cache=kv_cache,
+ workspace_buffer=self.workspace_buffer,
+ block_tables=self.forward_metadata.page_table,
+ seq_lens=self.forward_metadata.cache_seqlens_int32,
+ max_q_len=self.forward_metadata.max_seq_len_q,
+ max_kv_len=self.forward_metadata.max_seq_len_k,
+ bmm1_scale=bmm1_scale,
+ bmm2_scale=bmm2_scale,
+ batch_size=forward_batch.batch_size,
+ cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
+ cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
+ window_left=layer.sliding_window_size,
+ # TODO: add attention_sink operation or nvfp4 scale factor if needed
+ sinks=attention_sink,
+ )
+
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)
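
A note on the shape handling shared by forward_decode and forward_extend above: the KV buffer returned by get_kv_buffer is token-major (as the view in the code implies), so it is first viewed as fixed-size pages and then permuted into the [num_pages, num_kv_heads, page_size, head_dim] layout expected by the flashinfer TRT-LLM kernels. A small standalone sketch with made-up sizes:

import torch

num_tokens, num_kv_heads, head_dim, page_size = 64, 8, 128, 16

# The pool stores one row per cached token: [num_tokens, num_kv_heads, head_dim].
k_cache = torch.randn(num_tokens, num_kv_heads, head_dim)

# view    -> [num_pages, page_size, num_kv_heads, head_dim]
# permute -> [num_pages, num_kv_heads, page_size, head_dim]
k_paged = k_cache.view(-1, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3)
assert k_paged.shape == (num_tokens // page_size, num_kv_heads, page_size, head_dim)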