sglang-0.5.4.post1-py3-none-any.whl → sglang-0.5.4.post2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa/quant_k_cache.py

@@ -206,6 +206,8 @@ def _quantize_k_cache_fast_kernel(
 
 
  if __name__ == "__main__":
+ import dequant_k_cache
+
  for num_blocks, block_size in [
  (1, 1),
  (10, 64),
@@ -217,21 +219,9 @@ if __name__ == "__main__":
  dtype=torch.bfloat16,
  device="cuda",
  )
- # temp debug
- # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
 
  ref_quant = _quantize_k_cache_slow(input_k_cache)
  actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
- # print(f"{input_k_cache=}")
- # print(f"{ref_quant=}")
- # print(f"{actual_quant=}")
- # print(f"{ref_quant == actual_quant=}")
- # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
- # print(f"{ref_quant.view(torch.bfloat16)=}")
- # print(f"{actual_quant.view(torch.bfloat16)=}")
- # assert torch.all(ref_quant == actual_quant)
-
- import dequant_k_cache
 
  ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
  ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
@@ -252,4 +242,46 @@ if __name__ == "__main__":
  ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
  )
 
+ # test dequant_k_cache_paged
+ page_table_1 = torch.arange(
+ num_blocks * block_size, dtype=torch.int32, device="cuda"
+ )
+ actual_dequant_paged = dequant_k_cache.dequantize_k_cache_paged(
+ actual_quant, page_table_1
+ ).reshape(actual_actual_dequant.shape)
+ print(f"{torch.mean(actual_actual_dequant - actual_dequant_paged)=}")
+ torch.testing.assert_close(
+ ref_ref_dequant, actual_dequant_paged, atol=0.2, rtol=0.2
+ )
+
  print("Passed")
+ print("Do benchmark...")
+
+ for num_blocks, block_size in [
+ (1, 64),
+ (64, 64),
+ (128, 64),
+ (256, 64),
+ (512, 64),
+ (1024, 64),
+ (2048, 64),
+ ]:
+ dim_nope_and_rope = 512 + 64
+
+ input_k_cache = torch.randn(
+ (num_blocks, block_size, 1, dim_nope_and_rope),
+ dtype=torch.bfloat16,
+ device="cuda",
+ )
+
+ actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
+
+ page_table_1 = torch.arange(
+ num_blocks * block_size, dtype=torch.int32, device="cuda"
+ )
+
+ def run_ans():
+ return dequant_k_cache.dequantize_k_cache_paged(actual_quant, page_table_1)
+
+ ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
+ print(f"seq_kv: {num_blocks * block_size}, time: {ans_time * 1e6: 4.0f} us")
sglang/srt/layers/attention/nsa/transform_index.py

@@ -103,7 +103,7 @@ def transform_index_page_table_decode_ref(
  result = torch.empty_like(topk_indices, dtype=torch.int32)
  assert result.shape == topk_indices.shape
  torch.gather(
- page_table,
+ page_table.to(result.dtype),
  dim=1,
  index=topk_indices.clamp(min=0),
  out=result,
sglang/srt/layers/attention/nsa_backend.py

@@ -1,12 +1,14 @@
  from __future__ import annotations
 
  from dataclasses import dataclass
+ from enum import IntEnum, auto
  from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
 
  import torch
 
  from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged
  from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
  from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
  from sglang.srt.layers.attention.nsa.transform_index import (
@@ -98,11 +100,27 @@ class NSAMetadata:
  nsa_max_seqlen_q: Literal[1] = 1 # always 1 for decode, variable for extend
 
  flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
+ # The sum of sequence lengths for key, prefill only
+ seq_lens_sum: Optional[int] = None
+ # The flattened 1D page table with shape (seq_lens_sum,), prefill only
+ # this table is always with page_size = 1
+ page_table_1_flattened: Optional[torch.Tensor] = None
+ # The offset of topk indices in ragged kv, prefill only
+ # shape: (seq_lens_sum,)
+ topk_indices_offset: Optional[torch.Tensor] = None
+
+
+ class TopkTransformMethod(IntEnum):
+ # Transform topk indices to indices to the page table (page_size = 1)
+ PAGED = auto()
+ # Transform topk indices to indices to ragged kv (non-paged)
+ RAGGED = auto()
 
 
  @dataclass(frozen=True)
  class NSAIndexerMetadata(BaseIndexerMetadata):
  attn_metadata: NSAMetadata
+ topk_transform_method: TopkTransformMethod
 
  def get_seqlens_int32(self) -> torch.Tensor:
  return self.attn_metadata.cache_seqlens_int32
@@ -118,23 +136,36 @@ class NSAIndexerMetadata(BaseIndexerMetadata):
  logits: torch.Tensor,
  topk: int,
  ) -> torch.Tensor:
- from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+ from sgl_kernel import (
+ fast_topk_transform_fused,
+ fast_topk_transform_ragged_fused,
+ fast_topk_v2,
+ )
 
  if not NSA_FUSE_TOPK:
  return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
-
- # NOTE(dark): if fused, we return a transformed page table directly
- return fast_topk_transform_fused(
- score=logits,
- lengths=self.get_seqlens_expanded(),
- page_table_size_1=self.attn_metadata.page_table_1,
- cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
- topk=topk,
- )
+ elif self.topk_transform_method == TopkTransformMethod.PAGED:
+ # NOTE(dark): if fused, we return a transformed page table directly
+ return fast_topk_transform_fused(
+ score=logits,
+ lengths=self.get_seqlens_expanded(),
+ page_table_size_1=self.attn_metadata.page_table_1,
+ cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+ topk=topk,
+ )
+ elif self.topk_transform_method == TopkTransformMethod.RAGGED:
+ return fast_topk_transform_ragged_fused(
+ score=logits,
+ lengths=self.get_seqlens_expanded(),
+ topk_indices_offset=self.attn_metadata.topk_indices_offset,
+ topk=topk,
+ )
+ else:
+ assert False, f"Unsupported {self.topk_transform_method = }"
 
 
  def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
- assert seqlens.dtype == torch.int32 and seqlens.is_cuda
+ assert seqlens.dtype == torch.int32
  return torch.nn.functional.pad(
  torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
  )
@@ -181,6 +212,7 @@ class NativeSparseAttnBackend(AttentionBackend):
  global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
  NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
  NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
+ self.enable_auto_select_prefill_impl = NSA_PREFILL_IMPL == "flashmla_auto"
 
  self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
 
@@ -231,10 +263,16 @@
  cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
  assert forward_batch.seq_lens_cpu is not None
  max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
+ # [b, max_seqlen_k]
  page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, :max_seqlen_k
  ]
 
+ page_table_1_flattened = None
+ topk_indices_offset = None
+ self.set_nsa_prefill_impl(forward_batch)
+ topk_transform_method = self.get_topk_transform_method()
+
  if forward_batch.forward_mode.is_decode_or_idle():
  extend_seq_lens_cpu = [1] * batch_size
  max_seqlen_q = 1
@@ -295,6 +333,7 @@
  else:
  max_seqlen_q = max_seqlen_k
  cu_seqlens_q = cu_seqlens_k
+
  seqlens_expanded = torch.cat(
  [
  torch.arange(
@@ -310,6 +349,24 @@
  )
  ]
  )
+
+ if topk_transform_method == TopkTransformMethod.RAGGED:
+ page_table_1_flattened = torch.cat(
+ [
+ page_table[i, :kv_len]
+ for i, kv_len in enumerate(
+ forward_batch.seq_lens_cpu.tolist(),
+ )
+ ]
+ )
+ assert (
+ page_table_1_flattened.shape[0] == forward_batch.seq_lens_sum
+ ), f"{page_table_1_flattened.shape[0] = } must be the same as {forward_batch.seq_lens_sum = }"
+
+ topk_indices_offset = torch.repeat_interleave(
+ cu_seqlens_k[:-1],
+ forward_batch.extend_seq_lens,
+ )
  else:
  assert False, f"Unsupported {forward_batch.forward_mode = }"
 
@@ -328,7 +385,9 @@
  max_seq_len_k=max_seqlen_k,
  cu_seqlens_q=cu_seqlens_q,
  cu_seqlens_k=cu_seqlens_k,
+ seq_lens_sum=forward_batch.seq_lens_sum,
  page_table_1=page_table,
+ page_table_1_flattened=page_table_1_flattened,
  flashmla_metadata=(
  self._compute_flashmla_metadata(
  cache_seqlens=nsa_cache_seqlens_int32,
@@ -344,6 +403,7 @@
  nsa_extend_seq_lens_list=extend_seq_lens_cpu,
  real_page_table=self._transform_table_1_to_real(page_table),
  nsa_max_seqlen_q=1,
+ topk_indices_offset=topk_indices_offset,
  )
 
  self.forward_metadata = metadata
@@ -396,6 +456,8 @@
  forward_mode: ForwardMode,
  spec_info: Optional[SpecInput],
  ):
+ self.set_nsa_prefill_impl(forward_batch=None)
+
  """Initialize forward metadata for capturing CUDA graph."""
  if forward_mode.is_decode_or_idle():
  # Normal Decode
@@ -586,6 +648,8 @@
  """Initialize forward metadata for replaying CUDA graph."""
  assert seq_lens_cpu is not None
 
+ self.set_nsa_prefill_impl(forward_batch=None)
+
  seq_lens = seq_lens[:bs]
  seq_lens_cpu = seq_lens_cpu[:bs]
  req_pool_indices = req_pool_indices[:bs]
@@ -780,17 +844,31 @@
  q_rope = q_all[:, :, layer.v_head_dim :]
 
  # NOTE(dark): here, we use page size = 1
-
+ topk_transform_method = self.get_topk_transform_method()
  if NSA_FUSE_TOPK:
  page_table_1 = topk_indices
  else:
- assert metadata.nsa_extend_seq_lens_list is not None
- page_table_1 = transform_index_page_table_prefill(
- page_table=metadata.page_table_1,
- topk_indices=topk_indices,
- extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
- page_size=1,
- )
+ if topk_transform_method == TopkTransformMethod.RAGGED:
+ topk_indices_offset = metadata.topk_indices_offset
+ assert topk_indices_offset is not None
+ mask = topk_indices != -1
+ topk_indices_offset = (
+ topk_indices_offset.unsqueeze(1)
+ if topk_indices_offset.ndim == 1
+ else topk_indices_offset
+ )
+ topk_indices = torch.where(
+ mask, topk_indices + topk_indices_offset, topk_indices
+ )
+ elif topk_transform_method == TopkTransformMethod.PAGED:
+ assert metadata.nsa_extend_seq_lens_list is not None
+ page_table_1 = transform_index_page_table_prefill(
+ page_table=metadata.page_table_1,
+ topk_indices=topk_indices,
+ extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+ page_size=1,
+ )
+
  if NSA_PREFILL_IMPL == "tilelang":
  if q_rope is not None:
  q_all = torch.cat([q_nope, q_rope], dim=-1)
@@ -804,6 +882,22 @@
  elif NSA_PREFILL_IMPL == "flashmla_sparse":
  if q_rope is not None:
  q_all = torch.cat([q_nope, q_rope], dim=-1)
+
+ # NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here,
+ # because the flashmla_sparse kernel doesn't support fp8 compute
+ if topk_transform_method == TopkTransformMethod.RAGGED:
+ if any(forward_batch.extend_prefix_lens_cpu):
+ page_table_1_flattened = (
+ self.forward_metadata.page_table_1_flattened
+ )
+ assert page_table_1_flattened is not None
+ kv_cache = dequantize_k_cache_paged(
+ kv_cache, page_table_1_flattened
+ )
+ else:
+ kv_cache = torch.cat([k, k_rope], dim=-1)
+ page_table_1 = topk_indices
+
  return self._forward_flashmla_sparse(
  q_all=q_all,
  kv_cache=kv_cache,
@@ -1004,7 +1098,7 @@
  page_table_1: torch.Tensor,
  sm_scale: float,
  ) -> torch.Tensor:
- from flash_mla import flash_mla_sparse_fwd
+ from sgl_kernel.flash_mla import flash_mla_sparse_fwd
 
  o, _, _ = flash_mla_sparse_fwd(
  q=q_all,
@@ -1025,7 +1119,7 @@
  metadata: NSAMetadata,
  page_table_1,
  ) -> torch.Tensor:
- from flash_mla import flash_mla_with_kvcache
+ from sgl_kernel.flash_mla import flash_mla_with_kvcache
 
  cache_seqlens = metadata.nsa_cache_seqlens_int32
 
@@ -1121,13 +1215,53 @@
  """Get the fill value for sequence length in CUDA graph."""
  return 1
 
+ def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None) -> str:
+ from sglang.srt.utils import is_blackwell
+
+ global NSA_PREFILL_IMPL
+ if self.enable_auto_select_prefill_impl:
+ if self.nsa_kv_cache_store_fp8:
+ if (
+ is_blackwell()
+ and forward_batch is not None
+ and forward_batch.forward_mode == ForwardMode.EXTEND
+ ):
+ total_kv_tokens = forward_batch.seq_lens_sum
+ total_q_tokens = forward_batch.extend_num_tokens
+ # Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged
+ if total_kv_tokens < total_q_tokens * 512:
+ NSA_PREFILL_IMPL = "flashmla_sparse"
+ return
+ NSA_PREFILL_IMPL = "flashmla_kv"
+ else:
+ # bf16 kv cache
+ NSA_PREFILL_IMPL = "flashmla_sparse"
+
+ def get_topk_transform_method(self) -> TopkTransformMethod:
+ """
+ NSA_FUSE_TOPK controls whether to fuse the topk transform into the topk kernel.
+ This method is used to select the topk transform method which can be fused or unfused.
+ """
+ if (
+ # disable for MTP
+ self.nsa_kv_cache_store_fp8
+ and NSA_PREFILL_IMPL == "flashmla_sparse"
+ ):
+ topk_transform_method = TopkTransformMethod.RAGGED
+ else:
+ topk_transform_method = TopkTransformMethod.PAGED
+ return topk_transform_method
+
  def get_indexer_metadata(
  self, layer_id: int, forward_batch: ForwardBatch
  ) -> NSAIndexerMetadata:
- return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+ return NSAIndexerMetadata(
+ attn_metadata=self.forward_metadata,
+ topk_transform_method=self.get_topk_transform_method(),
+ )
 
  def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
- from flash_mla import get_mla_metadata
+ from sgl_kernel.flash_mla import get_mla_metadata
 
  flashmla_metadata, num_splits = get_mla_metadata(
  cache_seqlens=cache_seqlens,
sglang/srt/layers/attention/triton_backend.py

@@ -92,7 +92,10 @@ class TritonAttnBackend(AttentionBackend):
  self.num_kv_head = model_runner.model_config.get_num_kv_heads(
  get_attention_tp_size()
  )
- if model_runner.hybrid_gdn_config is not None:
+ if (
+ model_runner.hybrid_gdn_config is not None
+ or model_runner.kimi_linear_config is not None
+ ):
  # For hybrid linear models, layer_id = 0 may not be full attention
  self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
  else:
sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]
 
- if (
- any(forward_batch.extend_prefix_lens_cpu)
- or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
- ):
+ if any(
+ forward_batch.extend_prefix_lens_cpu
+ ) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
  extend_seq_lens = forward_batch.extend_seq_lens
  metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
  metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -529,6 +528,8 @@
  layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
 
+ if self.data_type == torch.float8_e4m3fn:
+ q = q.to(torch.float8_e4m3fn)
  q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
  k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
  # shape conversion:
@@ -567,6 +568,7 @@
  window_left=layer.sliding_window_size,
  # TODO: add attention_sink operation or nvfp4 scale factor if needed
  sinks=attention_sink,
+ out_dtype=self.q_data_type, # model_runner.dtype
  )
 
  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -586,6 +588,9 @@
  forward_batch.token_to_kv_pool.set_kv_buffer(
  layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
+
+ if self.data_type == torch.float8_e4m3fn:
+ q = q.to(torch.float8_e4m3fn)
  q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
  # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
  k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -625,6 +630,7 @@
  window_left=layer.sliding_window_size,
  # TODO: add attention_sink operation or nvfp4 scale factor if needed
  sinks=attention_sink,
+ out_dtype=self.q_data_type, # model_runner.dtype
  )
 
  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -944,8 +944,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
  )
  else:
- seq_lens = forward_batch.seq_lens.to(torch.int32)
- max_seq_len = metadata.max_seq_len_k
+ # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
+ # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
+ # This will ensure queries align with kvs correctly when calling
+ # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
+ seq_lens = (
+ forward_batch.seq_lens
+ - metadata.seq_lens_q
+ + metadata.max_seq_len_q
+ ).to(torch.int32)
+ max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
  # Check if we're in CUDA graph mode (buffers are pre-allocated)
  if self.padded_q_buffer is not None:
  # Use pre-allocated buffer for CUDA graph compatibility
sglang/srt/layers/communicator.py

@@ -15,7 +15,7 @@
  from dataclasses import dataclass
  from enum import Enum, auto
  from functools import partial
- from typing import Dict, Optional
+ from typing import Dict, List, Optional
 
  import torch
 
@@ -216,6 +216,28 @@ class LayerCommunicator:
  get_global_server_args().speculative_algorithm
  )
 
+ def prepare_attn_and_capture_last_layer_outputs(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ forward_batch: ForwardBatch,
+ captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
+ ):
+ hidden_states, residual = self.prepare_attn(
+ hidden_states, residual, forward_batch
+ )
+ if captured_last_layer_outputs is not None:
+ gathered_last_layer_output = self._communicate_simple_fn(
+ hidden_states=residual,
+ forward_batch=forward_batch,
+ context=self._context,
+ )
+ if gathered_last_layer_output is residual:
+ # Clone to avoid modifying the original residual by Custom RMSNorm inplace operation
+ gathered_last_layer_output = residual.clone()
+ captured_last_layer_outputs.append(gathered_last_layer_output)
+ return hidden_states, residual
+
  def prepare_attn(
  self,
  hidden_states: torch.Tensor,
sglang/srt/layers/layernorm.py

@@ -20,7 +20,12 @@ import torch
  import torch.nn as nn
  from packaging.version import Version
 
+ from sglang.srt.batch_invariant_ops import (
+ is_batch_invariant_mode_enabled,
+ rms_norm_batch_invariant,
+ )
  from sglang.srt.custom_op import CustomOp
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.utils import (
  cpu_has_amx_support,
  get_bool_env_var,
@@ -90,8 +95,6 @@ class RMSNorm(CustomOp):
  )
  if _use_aiter:
  self._forward_method = self.forward_aiter
- if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
- self._forward_method = self.forward_native
 
  def forward_cuda(
  self,
@@ -100,6 +103,17 @@
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  if self.variance_size_override is not None:
  return self.forward_native(x, residual)
+ if is_batch_invariant_mode_enabled():
+ if (
+ residual is not None
+ or get_global_server_args().rl_on_policy_target == "fsdp"
+ ):
+ return self.forward_native(x, residual)
+ return rms_norm_batch_invariant(
+ x,
+ self.weight.data,
+ self.variance_epsilon,
+ )
  if residual is not None:
  fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
  return x, residual
sglang/srt/layers/logits_processor.py

@@ -38,7 +38,6 @@ from sglang.srt.layers.dp_attention import (
  get_dp_device,
  get_dp_dtype,
  get_dp_hidden_size,
- get_local_attention_dp_size,
  )
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import (
@@ -47,7 +46,7 @@
  ForwardMode,
  )
  from sglang.srt.server_args import get_global_server_args
- from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
+ from sglang.srt.utils import is_npu, use_intel_amx_backend
 
  logger = logging.getLogger(__name__)
 
@@ -135,10 +134,7 @@ class LogitsMetadata:
  @classmethod
  def from_forward_batch(cls, forward_batch: ForwardBatch):
  if (
- (
- forward_batch.forward_mode.is_extend()
- or forward_batch.forward_mode.is_split_prefill()
- )
+ forward_batch.forward_mode.is_extend()
  and forward_batch.return_logprob
  and not forward_batch.forward_mode.is_target_verify()
  ):
@@ -252,10 +248,6 @@ class LogitsProcessor(nn.Module):
  ):
  self.final_logit_softcapping = None
 
- self.debug_tensor_dump_output_folder = (
- get_global_server_args().debug_tensor_dump_output_folder
- )
-
  def compute_logprobs_for_multi_item_scoring(
  self,
  input_ids,
@@ -389,8 +381,8 @@
  input_logprob_indices = None
  elif (
  logits_metadata.forward_mode.is_extend()
- or logits_metadata.forward_mode.is_split_prefill()
- ) and not logits_metadata.extend_return_logprob:
+ and not logits_metadata.extend_return_logprob
+ ):
  # Prefill without input logprobs.
  if logits_metadata.padded_static_len < 0:
  last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -463,14 +455,6 @@
  logits[sample_indices] if sample_indices is not None else logits
  )
 
- if self.debug_tensor_dump_output_folder:
- assert (
- not self.do_tensor_parallel_all_gather
- or get_local_attention_dp_size() == 1
- ), "dp attention + sharded lm_head doesn't support full logits"
- full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
- dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
-
  hidden_states_to_store: Optional[torch.Tensor] = None
  if logits_metadata.capture_hidden_mode.need_capture():
  if logits_metadata.capture_hidden_mode.is_full():
sglang/srt/layers/moe/ep_moe/layer.py

@@ -131,23 +131,6 @@ class DeepEPMoE(FusedMoE):
  )
  # the last one is invalid rank_id
  self.expert_mask[:-1] = 1
- elif not _is_npu:
- self.w13_weight_fp8 = (
- self.w13_weight,
- (
- self.w13_weight_scale_inv
- if self.use_block_quant or self.use_w4afp8
- else self.w13_weight_scale
- ),
- )
- self.w2_weight_fp8 = (
- self.w2_weight,
- (
- self.w2_weight_scale_inv
- if self.use_block_quant or self.use_w4afp8
- else self.w2_weight_scale
- ),
- )
 
  def forward(
  self,
@@ -235,7 +218,6 @@
  hidden_states=output,
  topk_ids=dispatch_output.topk_ids,
  topk_weights=dispatch_output.topk_weights,
- overlap_args=down_gemm_overlap_args,
  )
 
  def combine(