sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -308,7 +308,7 @@ class FlashAttentionBackend(AttentionBackend):
  ), "Sliding window and cross attention are not supported together"

  self.forward_metadata: FlashAttentionMetadata = None
- # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
+ # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
  self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
  self.max_context_len = model_runner.model_config.context_len
  self.device = model_runner.device
@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
  """Initialize forward metadata hence all layers in the forward pass can reuse it."""
  metadata = FlashAttentionMetadata()
  seqlens_in_batch = forward_batch.seq_lens
- batch_size = len(seqlens_in_batch)
+ batch_size = forward_batch.batch_size
  device = seqlens_in_batch.device

  if forward_batch.forward_mode.is_decode_or_idle():
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
  # Use precomputed metadata across all layers
  metadata = self.forward_metadata
  local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
- use_local_attention = (
- self.attention_chunk_size is not None and local_attn_metadata is not None
+ use_local_attn = (
+ self.attention_chunk_size is not None
+ and local_attn_metadata is not None
+ and (hasattr(layer, "use_irope") and layer.use_irope)
  )
  # We do cascade attention for Draft Decode with topk > 1
  use_cascade_attn = self.topk > 1
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  )
- elif use_local_attention:
+ elif use_local_attn:
  # Use chunked (local) attention batching for self-attention
  o = flash_attn_with_kvcache(
  q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
  page_table=local_attn_metadata.local_block_table,
  cache_seqlens=local_attn_metadata.local_seqused_k,
  cu_seqlens_q=local_attn_metadata.local_query_start_loc,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
+ cu_seqlens_k_new=None,
  max_seqlen_q=local_attn_metadata.local_max_query_len,
  softmax_scale=layer.scaling,
  causal=True,
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
-
  # This is being used by normal decode and draft decode when topk == 1
  self.decode_cuda_graph_metadata = {
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  }

+ # Only allocate local attention buffers if local attention is enabled
+ # This prevents OOM errors when local attention is not being used
+ if self.attention_chunk_size is not None:
+ # Estimate maximum sizes for local attention metadata
+ max_seq_len = self.max_context_len
+ page_size = self.page_size or 1
+ attn_chunk_size = self.attention_chunk_size
+ max_virtual_batches = max_bs * (
+ (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+ )
+ max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+ max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+ self.decode_cuda_graph_local_attn_metadata = {
+ "local_query_start_loc": torch.zeros(
+ max_virtual_batches + 1, dtype=torch.int32, device=self.device
+ ),
+ "local_seqused_k": torch.zeros(
+ max_virtual_batches, dtype=torch.int32, device=self.device
+ ),
+ "local_block_table": torch.zeros(
+ max_virtual_batches,
+ max_blocks_per_seq * max_pages_per_block,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
  # This is used by draft decode's first half of metadata when topk > 1
  if self.topk > 1:
  self.draft_decode_metadata_topk_normal = {
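To make the ceil-division sizing in the hunk above concrete, here is a small standalone sketch with hypothetical values (max_bs=8, max_seq_len=8192, attn_chunk_size=2048, page_size=16; none of these numbers come from the diff):

# Hypothetical sizes, purely to illustrate the buffer-shape arithmetic above.
max_bs, max_seq_len, attn_chunk_size, page_size = 8, 8192, 2048, 16

max_virtual_batches = max_bs * ((max_seq_len + attn_chunk_size - 1) // attn_chunk_size)  # 8 * 4 = 32
max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size              # 4
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size                     # 128

# With these numbers, "local_block_table" would be preallocated as a (32, 512) int32 tensor.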
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
  )
  self.decode_cuda_graph_metadata[bs] = metadata

+ if self.attention_chunk_size is not None:
+ metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+ local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+ "local_query_start_loc"
+ ],
+ local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+ "local_seqused_k"
+ ],
+ local_block_table=self.decode_cuda_graph_local_attn_metadata[
+ "local_block_table"
+ ],
+ local_max_query_len=1,
+ local_max_seq_len=1,
+ )
+
  elif forward_mode.is_target_verify():
  if self.topk <= 1:
  metadata.cache_seqlens_int32 = self.target_verify_metadata[
@@ -1525,12 +1569,9 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
  self.speculative_step_id + 1
  )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
  )
  )

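The same pattern recurs in the hunks below: instead of building a fresh padded cumsum on every call, the cumsum is written into the [1:] slice of a preallocated cu_seqlens_k buffer whose element 0 stays zero, so the tensor address remains stable across CUDA graph replays. A minimal standalone sketch of the equivalence (hypothetical lengths, not sglang code):

import torch

cache_seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)  # hypothetical per-request lengths

# Old pattern: allocates a new tensor on every call.
cu_old = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)

# New pattern: write into a preallocated buffer; element 0 is already zero.
cu_new = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32)
cu_new[1:].copy_(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32))

assert torch.equal(cu_old, cu_new)  # both are [0, 3, 8, 10]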
@@ -1554,12 +1595,9 @@ class FlashAttentionBackend(AttentionBackend):
  # metadata.max_seq_len_q = self.topk, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
  )
  )

@@ -1578,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata_expand.page_table[: cache_loc.shape[0]].copy_(
  cache_loc[:, :decode_length].contiguous().to(torch.int32)
  )
- # TODO: we need to test this part for llama 4 eagle case
- self._init_local_attn_metadata(metadata, device)
+ # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
  else:
  metadata = self.decode_cuda_graph_metadata[bs]
  # Normal Decode
@@ -1587,8 +1624,9 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = max_len

  metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ # Optimize cumulative sequence length calculation
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
  )

  max_seq_pages = (
@@ -1604,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.page_table[:, :max_seq_pages].copy_(page_indices)
  metadata.page_table[:, max_seq_pages:].fill_(0)

- self._init_local_attn_metadata(metadata, device)
+ self._update_local_attn_metadata_for_replay(metadata, bs)
  elif forward_mode.is_target_verify():
  if self.topk <= 1:
  metadata = self.target_verify_metadata[bs]
@@ -1615,13 +1653,8 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = (
  seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
  )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
  )
  max_seq_pages = (
  metadata.max_seq_len_k + self.page_size - 1
@@ -1640,13 +1673,8 @@ class FlashAttentionBackend(AttentionBackend):
  # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
  )
  page_table = self.req_to_token[
  req_pool_indices, : metadata.max_seq_len_k
@@ -1704,14 +1732,11 @@ class FlashAttentionBackend(AttentionBackend):
  metadata_expand.cache_seqlens_int32.copy_(
  mask.sum(dim=1).to(torch.int32)
  )
- metadata_expand.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata_expand.cache_seqlens_int32,
- dim=0,
- dtype=torch.int32,
- ),
- (1, 0),
+ metadata_expand.cu_seqlens_k[1:].copy_(
+ torch.cumsum(
+ metadata_expand.cache_seqlens_int32,
+ dim=0,
+ dtype=torch.int32,
  )
  )
  metadata_expand.max_seq_len_k = (
@@ -1722,11 +1747,8 @@ class FlashAttentionBackend(AttentionBackend):
  # Only support encoder size 1 for now
  metadata.encoder_max_seq_len_k = encoder_lens[0]
  metadata.encoder_lens_int32.copy_(encoder_lens[:1])
- metadata.encoder_cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
- (1, 0),
- )
+ metadata.encoder_cu_seqlens_k[1:].copy_(
+ torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
  )

  metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
@@ -1776,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
  page_table,
  self.page_size,
  )
+
  local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
  local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
  local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1785,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
  )
  metadata.local_attn_metadata = local_metadata

+ def _update_local_attn_metadata_for_replay(
+ self, metadata: FlashAttentionMetadata, bs: int
+ ):
+ """Update preallocated local attention metadata in-place before CUDA graph replay."""
+ if self.attention_chunk_size is None:
+ return
+
+ # Access preallocated buffers
+ local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+ "local_query_start_loc"
+ ]
+ local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+ local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+ "local_block_table"
+ ]
+ cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+ # Create a modified version for local attention that only processes the last token
+ # This mimics the normal decode pattern
+ cu_seqlens_q = torch.arange(
+ bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+ )
+ seqlens = metadata.cache_seqlens_int32[:bs]
+ # Slice the page_table to match the batch size and actual sequence length
+ # This serves three important purposes:
+ # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+ # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+ # 3. Prevents zeros in the block table which can cause garbage output during replay
+ #
+ # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+ # beyond the actual sequence length, leading to incorrect attention calculations
+ max_seq_len = int(seqlens.max().item())
+ sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+ cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+ seqlens_np = seqlens.cpu().numpy()
+ (
+ seqlens_q_local_np,
+ cu_seqlens_q_local_np,
+ seqlens_k_local_np,
+ block_table_local,
+ ) = make_local_attention_virtual_batches(
+ self.attention_chunk_size,
+ cu_seqlens_q_np,
+ seqlens_np,
+ sliced_page_table,
+ self.page_size,
+ )
+
+ # Convert back to tensors
+ device = local_q_buf.device
+ cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+ seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+ block_table_local = block_table_local.to(device)
+ # Get sizes
+ q_len = cu_seqlens_q_local.shape[0]
+ k_len = seqlens_k_local.shape[0]
+ b0, b1 = block_table_local.shape
+
+ # In-place updates into preallocated tensors and zero out the unused space
+ local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+ local_q_buf[q_len:].fill_(0)
+ local_k_buf[:k_len].copy_(seqlens_k_local)
+ local_k_buf[k_len:].fill_(0)
+ local_block_buf[:b0, :b1].copy_(block_table_local)
+ local_block_buf[b0:, :].fill_(0)
+ local_block_buf[:b0, b1:].fill_(0)
+
+ if metadata.local_attn_metadata is not None:
+ lam = metadata.local_attn_metadata
+ lam.local_max_query_len = int(seqlens_q_local_np.max())
+ lam.local_max_seq_len = int(seqlens_k_local_np.max())
+

  class FlashAttentionMultiStepBackend:

@@ -16,8 +16,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
  import torch

  if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
- import torch._dynamo
+ import logging

+ torch._logging.set_logs(dynamo=logging.ERROR)
  torch._dynamo.config.suppress_errors = True

  from sglang.global_config import global_config
@@ -107,6 +108,7 @@ class FlashInferAttnBackend(AttentionBackend):
  if (
  "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
  or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+ or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
  ):
  global_config.flashinfer_workspace_size = 512 * 1024 * 1024

@@ -416,6 +418,7 @@ class FlashInferAttnBackend(AttentionBackend):

  logits_soft_cap = layer.logit_cap

+ q = q.contiguous()
  if not self.forward_metadata.use_ragged:
  if k is not None:
  assert v is not None
@@ -425,7 +428,7 @@ class FlashInferAttnBackend(AttentionBackend):
  )

  o = prefill_wrapper_paged.forward(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
  causal=not layer.is_cross_attention,
  sm_scale=layer.scaling,
@@ -435,20 +438,27 @@ class FlashInferAttnBackend(AttentionBackend):
  v_scale=layer.v_scale,
  )
  else:
- o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
- q.view(-1, layer.tp_q_head_num, layer.head_dim),
- k.view(-1, layer.tp_k_head_num, layer.head_dim),
- v.view(-1, layer.tp_v_head_num, layer.head_dim),
- causal=True,
- sm_scale=layer.scaling,
- logits_soft_cap=logits_soft_cap,
- )
-
  if self.forward_metadata.extend_no_prefix:
- o = o1
+ o = self.prefill_wrapper_ragged.forward(
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
+ causal=True,
+ sm_scale=layer.scaling,
+ logits_soft_cap=logits_soft_cap,
+ )
+
  else:
+ o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
+ causal=True,
+ sm_scale=layer.scaling,
+ logits_soft_cap=logits_soft_cap,
+ )
  o2, s2 = prefill_wrapper_paged.forward_return_lse(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
  causal=False,
  sm_scale=layer.scaling,
@@ -18,8 +18,9 @@ import torch
  import triton

  if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
- import torch._dynamo
+ import logging

+ torch._logging.set_logs(dynamo=logging.ERROR)
  torch._dynamo.config.suppress_errors = True

  from sglang.global_config import global_config
@@ -338,23 +339,39 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache: bool = True,
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
  ):

  cache_loc = forward_batch.out_cache_loc
  logits_soft_cap = layer.logit_cap
  prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
- qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
  k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

  # Save kv cache
  if save_kv_cache and k is not None:
  assert v is not None
  if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+ if k_rope is not None:
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ layer, cache_loc, k, k_rope
+ )
+ else:
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+ if q_rope is not None:
+ q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )

  if self.forward_metadata.use_ragged:
  # ragged prefill
- o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+ if q_rope is not None:
+ q = torch.cat([q, q_rope], dim=-1)
+ qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ if k_rope is not None:
+ k = torch.cat([k, k_rope], dim=-1)
+ o = self.prefill_wrapper_ragged.forward(
  qall,
  k.view(-1, layer.tp_k_head_num, layer.head_dim),
  v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
@@ -364,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  )
  else:
  # mla paged prefill
+ if q_rope is None:
+ qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ q, q_rope = (
+ qall[:, :, : layer.v_head_dim],
+ qall[:, :, layer.v_head_dim :],
+ )
+ o = q.new_empty(q.shape)
  o = prefill_wrapper_paged.run(
- qall[:, :, : layer.v_head_dim],
- qall[:, :, layer.v_head_dim :],
+ q,
+ q_rope,
  k_buf[:, :, : layer.v_head_dim],
  k_buf[:, :, layer.v_head_dim :],
+ out=o,
  )

  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -381,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache: bool = True,
+ # For multi-head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
  ):
  decode_wrapper = self.forward_metadata.decode_wrapper
  cache_loc = forward_batch.out_cache_loc
@@ -388,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  if k is not None:
  assert v is not None
  if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
- )
+ if k_rope is not None:
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ k_rope,
+ )
+ else:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ v,
+ )

  # Reshape inputs
- reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ if q_rope is not None:
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
+ else:
+ reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = reshaped_q[:, :, : layer.v_head_dim]
+ q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
  k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

+ o = q_nope.new_empty(q_nope.shape)
  # Direct call to run without the wrapper
  o = decode_wrapper.run(
- reshaped_q[:, :, : layer.v_head_dim],
- reshaped_q[:, :, layer.v_head_dim :],
+ q_nope,
+ q_rope,
  k_buffer[:, :, : layer.v_head_dim],
  k_buffer[:, :, layer.v_head_dim :],
+ out=o,
  )

  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
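The q_nope/q_rope split used in these MLA hunks follows the layout where the last dimension of each query head is the concatenation of the v_head_dim "nope" features and the rotary ("rope") features. A small standalone illustration with hypothetical dimensions (16 heads, v_head_dim=512, head_dim=576; not taken from any particular model config):

import torch

# Hypothetical MLA dimensions, for illustration only.
num_tokens, tp_q_head_num, v_head_dim, head_dim = 4, 16, 512, 576

qall = torch.randn(num_tokens, tp_q_head_num, head_dim)
q_nope = qall[:, :, :v_head_dim]   # compressed ("nope") part of the query
q_rope = qall[:, :, v_head_dim:]   # rotary-embedded part of the query

assert q_rope.shape[-1] == head_dim - v_head_dim  # 64 in this example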
@@ -0,0 +1,46 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from sgl_kernel import merge_state_v2
+
+ from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+
+
+ # Automatically fallback to the Triton kernel in some cases
+ # (e.g., for AMD GPUs, when the head dimension is not a multiple
+ # of 4 or 8, and in FP8 precision)
+ def _supported_dtypes(o: torch.Tensor) -> bool:
+ return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+ def _supported_headdim(o: torch.Tensor) -> bool:
+ headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+ if o.dtype == torch.float32:
+ return headdim % 4 == 0
+ return headdim % 8 == 0
+
+
+ def merge_state(
+ prefix_output: torch.Tensor,
+ prefix_lse: torch.Tensor,
+ suffix_output: torch.Tensor,
+ suffix_lse: torch.Tensor,
+ output: Optional[torch.Tensor] = None,
+ output_lse: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if (
+ _is_cuda
+ and _supported_dtypes(prefix_output)
+ and _supported_headdim(prefix_output)
+ ):
+ return merge_state_v2(
+ prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+ )
+ else:
+ # Fallback to Triton kernel
+ return merge_state_triton(
+ prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+ )
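For intuition about what both kernels compute, here is a plain-PyTorch reference of the standard log-sum-exp merge of two partial attention states. The shape convention below (output [num_tokens, num_heads, head_dim], lse [num_tokens, num_heads]) is an assumption for illustration; the actual kernel layouts are not specified in this diff.

import torch

def merge_state_reference(o1, lse1, o2, lse2):
    # Numerically stable merge: weight each partial output by exp(lse) and
    # renormalize, returning the merged output and the merged log-sum-exp.
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    denom = w1 + w2
    merged = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom.unsqueeze(-1)
    merged_lse = max_lse + torch.log(denom)
    return merged, merged_lse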
@@ -919,7 +919,7 @@ def _fwd_kernel(

  e_max = n_e_max

- # stage 2: compute the trianlge part
+ # stage 2: compute the triangle part

  cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
  for start_n in range(0, cur_block_m_end, BLOCK_N):