sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
  """Initialize forward metadata hence all layers in the forward pass can reuse it."""
  metadata = FlashAttentionMetadata()
  seqlens_in_batch = forward_batch.seq_lens
- batch_size = len(seqlens_in_batch)
+ batch_size = forward_batch.batch_size
  device = seqlens_in_batch.device
 
  if forward_batch.forward_mode.is_decode_or_idle():
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
  # Use precomputed metadata across all layers
  metadata = self.forward_metadata
  local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
- use_local_attention = (
- self.attention_chunk_size is not None and local_attn_metadata is not None
+ use_local_attn = (
+ self.attention_chunk_size is not None
+ and local_attn_metadata is not None
+ and (hasattr(layer, "use_irope") and layer.use_irope)
  )
  # We do cascade attention for Draft Decode with topk > 1
  use_cascade_attn = self.topk > 1
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  )
- elif use_local_attention:
+ elif use_local_attn:
  # Use chunked (local) attention batching for self-attention
  o = flash_attn_with_kvcache(
  q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
  page_table=local_attn_metadata.local_block_table,
  cache_seqlens=local_attn_metadata.local_seqused_k,
  cu_seqlens_q=local_attn_metadata.local_query_start_loc,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
+ cu_seqlens_k_new=None,
  max_seqlen_q=local_attn_metadata.local_max_query_len,
  softmax_scale=layer.scaling,
  causal=True,
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
-
  # This is being used by normal decode and draft decode when topk == 1
  self.decode_cuda_graph_metadata = {
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  }
 
+ # Only allocate local attention buffers if local attention is enabled
+ # This prevents OOM errors when local attention is not being used
+ if self.attention_chunk_size is not None:
+ # Estimate maximum sizes for local attention metadata
+ max_seq_len = self.max_context_len
+ page_size = self.page_size or 1
+ attn_chunk_size = self.attention_chunk_size
+ max_virtual_batches = max_bs * (
+ (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+ )
+ max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+ max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+ self.decode_cuda_graph_local_attn_metadata = {
+ "local_query_start_loc": torch.zeros(
+ max_virtual_batches + 1, dtype=torch.int32, device=self.device
+ ),
+ "local_seqused_k": torch.zeros(
+ max_virtual_batches, dtype=torch.int32, device=self.device
+ ),
+ "local_block_table": torch.zeros(
+ max_virtual_batches,
+ max_blocks_per_seq * max_pages_per_block,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
  # This is used by draft decode's first half of metadata when topk > 1
  if self.topk > 1:
  self.draft_decode_metadata_topk_normal = {
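
Note on the sizing arithmetic in the hunk above: the preallocated local-attention buffers are dimensioned for the worst case, where every request in the largest captured batch splits into ceil(max_context_len / attention_chunk_size) virtual batches. A standalone sketch of the same ceiling-division arithmetic, using assumed example values rather than sglang defaults:

# Sketch of the worst-case sizing behind decode_cuda_graph_local_attn_metadata.
# All concrete numbers below are illustrative assumptions, not sglang defaults.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

max_bs = 160            # assumed CUDA-graph max batch size
max_seq_len = 8192      # assumed model context length (self.max_context_len)
page_size = 1           # assumed KV-cache page size
attn_chunk_size = 2048  # assumed local-attention chunk size

# Each request can contribute up to ceil(max_seq_len / attn_chunk_size) virtual batches.
max_virtual_batches = max_bs * ceil_div(max_seq_len, attn_chunk_size)  # 160 * 4 = 640
max_blocks_per_seq = ceil_div(max_seq_len, attn_chunk_size)            # 4
max_pages_per_block = ceil_div(attn_chunk_size, page_size)             # 2048

# Resulting buffer shapes:
#   local_query_start_loc: (max_virtual_batches + 1,)                        -> (641,)
#   local_seqused_k:       (max_virtual_batches,)                            -> (640,)
#   local_block_table:     (max_virtual_batches,
#                           max_blocks_per_seq * max_pages_per_block)        -> (640, 8192)
print(max_virtual_batches, max_blocks_per_seq * max_pages_per_block)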
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
  )
  self.decode_cuda_graph_metadata[bs] = metadata
 
+ if self.attention_chunk_size is not None:
+ metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+ local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+ "local_query_start_loc"
+ ],
+ local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+ "local_seqused_k"
+ ],
+ local_block_table=self.decode_cuda_graph_local_attn_metadata[
+ "local_block_table"
+ ],
+ local_max_query_len=1,
+ local_max_seq_len=1,
+ )
+
  elif forward_mode.is_target_verify():
  if self.topk <= 1:
  metadata.cache_seqlens_int32 = self.target_verify_metadata[
@@ -1525,12 +1569,9 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
  self.speculative_step_id + 1
  )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
  )
  )
 
@@ -1554,12 +1595,9 @@ class FlashAttentionBackend(AttentionBackend):
  # metadata.max_seq_len_q = self.topk, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
  )
  )
 
@@ -1578,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata_expand.page_table[: cache_loc.shape[0]].copy_(
  cache_loc[:, :decode_length].contiguous().to(torch.int32)
  )
- # TODO: we need to test this part for llama 4 eagle case
- self._init_local_attn_metadata(metadata, device)
+ # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
  else:
  metadata = self.decode_cuda_graph_metadata[bs]
  # Normal Decode
@@ -1587,8 +1624,9 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = max_len
 
  metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ # Optimize cumulative sequence length calculation
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
  )
 
  max_seq_pages = (
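
The cu_seqlens_k changes in this and the surrounding hunks all follow one pattern: instead of rebuilding the cumulative-length tensor with torch.nn.functional.pad(torch.cumsum(...), (1, 0)), which allocates a new tensor on every replay, the cumsum is written into the tail of a buffer preallocated at capture time whose element 0 stays zero, so the CUDA graph keeps pointing at the same storage. A small CPU-only check that the two forms agree (buffer sizes here are made up for the example):

# Equivalence of F.pad(cumsum(seq_lens), (1, 0)) and an in-place write into cu_seqlens_k[1:].
# Illustrative only; names mirror the diff, sizes are invented.
import torch
import torch.nn.functional as F

max_bs = 8
seq_lens = torch.tensor([5, 3, 7, 2], dtype=torch.int32)
bs = seq_lens.numel()

# Old form: allocates a fresh tensor each call.
cu_old = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))

# New form: write into a preallocated buffer; index 0 is never touched and stays 0.
cu_seqlens_k = torch.zeros(max_bs + 1, dtype=torch.int32)
cu_seqlens_k[1 : bs + 1].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))

assert torch.equal(cu_old, cu_seqlens_k[: bs + 1])
print(cu_seqlens_k[: bs + 1])  # tensor([ 0,  5,  8, 15, 17], dtype=torch.int32)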
@@ -1604,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.page_table[:, :max_seq_pages].copy_(page_indices)
  metadata.page_table[:, max_seq_pages:].fill_(0)
 
- self._init_local_attn_metadata(metadata, device)
+ self._update_local_attn_metadata_for_replay(metadata, bs)
  elif forward_mode.is_target_verify():
  if self.topk <= 1:
  metadata = self.target_verify_metadata[bs]
@@ -1615,13 +1653,8 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = (
  seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
  )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
  )
  max_seq_pages = (
  metadata.max_seq_len_k + self.page_size - 1
@@ -1640,13 +1673,8 @@ class FlashAttentionBackend(AttentionBackend):
  # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
+ metadata.cu_seqlens_k[1:].copy_(
+ torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
  )
  page_table = self.req_to_token[
  req_pool_indices, : metadata.max_seq_len_k
@@ -1704,14 +1732,11 @@ class FlashAttentionBackend(AttentionBackend):
  metadata_expand.cache_seqlens_int32.copy_(
  mask.sum(dim=1).to(torch.int32)
  )
- metadata_expand.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata_expand.cache_seqlens_int32,
- dim=0,
- dtype=torch.int32,
- ),
- (1, 0),
+ metadata_expand.cu_seqlens_k[1:].copy_(
+ torch.cumsum(
+ metadata_expand.cache_seqlens_int32,
+ dim=0,
+ dtype=torch.int32,
  )
  )
  metadata_expand.max_seq_len_k = (
@@ -1722,11 +1747,8 @@ class FlashAttentionBackend(AttentionBackend):
  # Only support encoder size 1 for now
  metadata.encoder_max_seq_len_k = encoder_lens[0]
  metadata.encoder_lens_int32.copy_(encoder_lens[:1])
- metadata.encoder_cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
- (1, 0),
- )
+ metadata.encoder_cu_seqlens_k[1:].copy_(
+ torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
  )
 
  metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
@@ -1776,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
  page_table,
  self.page_size,
  )
+
  local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
  local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
  local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1785,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
  )
  metadata.local_attn_metadata = local_metadata
 
+ def _update_local_attn_metadata_for_replay(
+ self, metadata: FlashAttentionMetadata, bs: int
+ ):
+ """Update preallocated local attention metadata in-place before CUDA graph replay."""
+ if self.attention_chunk_size is None:
+ return
+
+ # Access preallocated buffers
+ local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+ "local_query_start_loc"
+ ]
+ local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+ local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+ "local_block_table"
+ ]
+ cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+ # Create a modified version for local attention that only processes the last token
+ # This mimics the normal decode pattern
+ cu_seqlens_q = torch.arange(
+ bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+ )
+ seqlens = metadata.cache_seqlens_int32[:bs]
+ # Slice the page_table to match the batch size and actual sequence length
+ # This serves three important purposes:
+ # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+ # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+ # 3. Prevents zeros in the block table which can cause garbage output during replay
+ #
+ # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+ # beyond the actual sequence length, leading to incorrect attention calculations
+ max_seq_len = int(seqlens.max().item())
+ sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+ cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+ seqlens_np = seqlens.cpu().numpy()
+ (
+ seqlens_q_local_np,
+ cu_seqlens_q_local_np,
+ seqlens_k_local_np,
+ block_table_local,
+ ) = make_local_attention_virtual_batches(
+ self.attention_chunk_size,
+ cu_seqlens_q_np,
+ seqlens_np,
+ sliced_page_table,
+ self.page_size,
+ )
+
+ # Convert back to tensors
+ device = local_q_buf.device
+ cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+ seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+ block_table_local = block_table_local.to(device)
+ # Get sizes
+ q_len = cu_seqlens_q_local.shape[0]
+ k_len = seqlens_k_local.shape[0]
+ b0, b1 = block_table_local.shape
+
+ # In-place updates into preallocated tensors and zero out the unused space
+ local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+ local_q_buf[q_len:].fill_(0)
+ local_k_buf[:k_len].copy_(seqlens_k_local)
+ local_k_buf[k_len:].fill_(0)
+ local_block_buf[:b0, :b1].copy_(block_table_local)
+ local_block_buf[b0:, :].fill_(0)
+ local_block_buf[:b0, b1:].fill_(0)
+
+ if metadata.local_attn_metadata is not None:
+ lam = metadata.local_attn_metadata
+ lam.local_max_query_len = int(seqlens_q_local_np.max())
+ lam.local_max_seq_len = int(seqlens_k_local_np.max())
+
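
The replay helper added above follows the usual CUDA-graph idiom for variable-length metadata: results for the current batch are copied into the front of fixed buffers that the captured graph already references, and the unused tail is zeroed so entries from an earlier, larger batch cannot leak into this replay. A generic sketch of that copy-prefix/zero-tail idiom (hypothetical helper, not the sglang method itself):

# Generic "copy prefix, zero tail" update for buffers reused across CUDA-graph replays.
# update_fixed_buffer is a hypothetical helper for illustration.
import torch

def update_fixed_buffer(buf: torch.Tensor, values: torch.Tensor) -> None:
    n = values.shape[0]
    assert n <= buf.shape[0], "buffer must be preallocated for the worst case"
    buf[:n].copy_(values)  # overwrite the live prefix in place (same storage, same pointer)
    buf[n:].fill_(0)       # clear leftovers from previous, longer batches

buf = torch.zeros(6, dtype=torch.int32)
update_fixed_buffer(buf, torch.tensor([4, 9, 1], dtype=torch.int32))
print(buf)  # tensor([4, 9, 1, 0, 0, 0], dtype=torch.int32)
update_fixed_buffer(buf, torch.tensor([7], dtype=torch.int32))
print(buf)  # tensor([7, 0, 0, 0, 0, 0], dtype=torch.int32)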
 
  class FlashAttentionMultiStepBackend:
 
sglang/srt/layers/attention/flashinfer_backend.py

@@ -15,6 +15,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
  import torch
 
+ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+ import logging
+
+ torch._logging.set_logs(dynamo=logging.ERROR)
+ torch._dynamo.config.suppress_errors = True
+
  from sglang.global_config import global_config
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
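
The new import-time block only takes effect when the SGLANG_ENABLE_TORCH_COMPILE environment variable equals "1"; it then lowers Dynamo logging to errors and tells torch._dynamo to suppress compile failures instead of raising. A defensive variant of the same gate, sketched with a defaulted environment read (an illustration, not the released code):

# Sketch of the same torch.compile noise/error gate, reading the env var with a default.
# Illustrative variant only; the released code indexes os.environ directly.
import logging
import os

import torch
import torch._dynamo  # explicit import; some torch builds expose this submodule lazily

if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
    # Keep Dynamo quiet and fall back to eager execution on compile failures.
    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True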
@@ -82,8 +88,6 @@ class FlashInferAttnBackend(AttentionBackend):
  self.max_context_len = model_runner.model_config.context_len
  self.skip_prefill = skip_prefill
  self.is_multimodal = model_runner.model_config.is_multimodal
- self.kv_cache_dtype = model_runner.kv_cache_dtype
- self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
  assert not (
  model_runner.sliding_window_size is not None
@@ -104,6 +108,7 @@ class FlashInferAttnBackend(AttentionBackend):
  if (
  "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
  or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+ or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
  ):
  global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
@@ -268,6 +273,12 @@ class FlashInferAttnBackend(AttentionBackend):
  cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
  ]
 
+ # Ensure tensors are properly allocated
+ for i in range(self.num_wrappers):
+ # Force allocation by performing a small operation
+ if len(self.cuda_graph_kv_indices[i]) > 0:
+ self.cuda_graph_kv_indices[i][0] = 0
+
  if not self.skip_prefill:
  self.cuda_graph_custom_mask = torch.zeros(
  (max_bs * self.max_context_len),
@@ -396,8 +407,6 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
- k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
- v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
  prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
  self._get_wrapper_idx(layer)
  ]
@@ -409,39 +418,47 @@ class FlashInferAttnBackend(AttentionBackend):
 
  logits_soft_cap = layer.logit_cap
 
+ q = q.contiguous()
  if not self.forward_metadata.use_ragged:
  if k is not None:
  assert v is not None
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, k_scale, v_scale
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
 
  o = prefill_wrapper_paged.forward(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
  causal=not layer.is_cross_attention,
  sm_scale=layer.scaling,
  window_left=layer.sliding_window_size,
  logits_soft_cap=logits_soft_cap,
- k_scale=k_scale,
- v_scale=v_scale,
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
  )
  else:
- o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
- q.view(-1, layer.tp_q_head_num, layer.head_dim),
- k.view(-1, layer.tp_k_head_num, layer.head_dim),
- v.view(-1, layer.tp_v_head_num, layer.head_dim),
- causal=True,
- sm_scale=layer.scaling,
- logits_soft_cap=logits_soft_cap,
- )
-
  if self.forward_metadata.extend_no_prefix:
- o = o1
+ o = self.prefill_wrapper_ragged.forward(
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
+ causal=True,
+ sm_scale=layer.scaling,
+ logits_soft_cap=logits_soft_cap,
+ )
+
  else:
+ o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
+ causal=True,
+ sm_scale=layer.scaling,
+ logits_soft_cap=logits_soft_cap,
+ )
  o2, s2 = prefill_wrapper_paged.forward_return_lse(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
  causal=False,
  sm_scale=layer.scaling,
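
In the restructured branch above, the no-prefix case now calls the ragged kernel directly, while the prefix-cache case still computes two partial results with their log-sum-exp values: o1/s1 from the ragged pass over the new tokens and o2/s2 from the paged pass over the cached prefix (the release also adds merge_state helpers under sglang/srt/layers/attention/, per the file list). Two such partial outputs can be combined by weighting each with its softmax denominator; a generic, numerically stable reference formula (not the sglang/FlashInfer merge kernel) looks like this:

# Illustrative log-sum-exp merge of two partial attention outputs over disjoint key sets.
# o1/o2: [tokens, heads, head_dim]; s1/s2: per-token, per-head log-sum-exp values.
import torch

def merge_attention_states(o1, s1, o2, s2):
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)          # weights from each branch's softmax denominator
    w2 = torch.exp(s2 - s_max)
    denom = (w1 + w2).unsqueeze(-1)
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom
    s = s_max + torch.log(w1 + w2)      # combined log-sum-exp, usable for a further merge
    return o, s

o1 = torch.randn(4, 8, 64); s1 = torch.randn(4, 8)
o2 = torch.randn(4, 8, 64); s2 = torch.randn(4, 8)
o, s = merge_attention_states(o1, s1, o2, s2)
print(o.shape, s.shape)  # torch.Size([4, 8, 64]) torch.Size([4, 8])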
@@ -452,7 +469,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, k_scale, v_scale
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
 
  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +483,6 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
- k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
- v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
  decode_wrapper = self.forward_metadata.decode_wrappers[
  self._get_wrapper_idx(layer)
  ]
@@ -481,16 +496,17 @@ class FlashInferAttnBackend(AttentionBackend):
  assert v is not None
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, k_scale, v_scale
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
 
+ # Call the wrapped function
  o = decode_wrapper.forward(
  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
  sm_scale=layer.scaling,
  logits_soft_cap=layer.logit_cap,
- k_scale=k_scale,
- v_scale=v_scale,
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
  )
 
  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1162,9 @@ def fast_decode_plan(
  pos_encoding_mode: str = "NONE",
  window_left: int = -1,
  logits_soft_cap: Optional[float] = None,
- data_type: Union[str, torch.dtype] = "float16",
  q_data_type: Optional[Union[str, torch.dtype]] = None,
+ kv_data_type: Optional[Union[str, torch.dtype]] = None,
+ data_type: Optional[Union[str, torch.dtype]] = None,
  sm_scale: Optional[float] = None,
  rope_scale: Optional[float] = None,
  rope_theta: Optional[float] = None,
@@ -1163,6 +1180,18 @@ def fast_decode_plan(
  if logits_soft_cap is None:
  logits_soft_cap = 0.0
 
+ # Handle data types consistently
+ if data_type is not None:
+ if q_data_type is None:
+ q_data_type = data_type
+ if kv_data_type is None:
+ kv_data_type = data_type
+ elif q_data_type is None:
+ q_data_type = "float16"
+
+ if kv_data_type is None:
+ kv_data_type = q_data_type
+
  if self.use_tensor_cores:
  qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
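
The added block gives fast_decode_plan an explicit dtype precedence: caller-supplied q_data_type/kv_data_type win, a legacy data_type argument fills in whichever of the two is missing, queries default to float16 when nothing is given, and the KV dtype falls back to the query dtype. Restated as a standalone helper (the function name and return shape are assumptions for this sketch):

# Illustrative restatement of the dtype-precedence logic added to fast_decode_plan.
from typing import Optional, Tuple, Union
import torch

DType = Union[str, torch.dtype]

def resolve_plan_dtypes(
    data_type: Optional[DType],
    q_data_type: Optional[DType],
    kv_data_type: Optional[DType],
) -> Tuple[DType, DType]:
    if data_type is not None:          # legacy argument fills any missing specific dtype
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:          # nothing given at all: default queries to float16
        q_data_type = "float16"
    if kv_data_type is None:           # KV dtype follows the query dtype by default
        kv_data_type = q_data_type
    return q_data_type, kv_data_type

print(resolve_plan_dtypes(None, None, None))            # ('float16', 'float16')
print(resolve_plan_dtypes("bfloat16", None, None))      # ('bfloat16', 'bfloat16')
print(resolve_plan_dtypes(None, torch.bfloat16, "float8_e4m3fn"))
# (torch.bfloat16, 'float8_e4m3fn')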
 
@@ -1178,36 +1207,33 @@ def fast_decode_plan(
  raise ValueError(
  "The size of indices should be less than or equal to the allocated buffer"
  )
- # Skip these copies because we directly write to them during prepartion
- # self._paged_kv_indptr_buf.copy_(indptr)
- # self._paged_kv_indices_buf[: len(indices)] = indices
- # self._paged_kv_last_page_len_buf.copy_(last_page_len)
  else:
  self._paged_kv_indptr_buf = indptr
  self._paged_kv_indices_buf = indices
  self._paged_kv_last_page_len_buf = last_page_len
- self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
-
- # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
- if not q_data_type:
- q_data_type = data_type
-
- if not hasattr(self, "empty_q_data"):
- self.empty_q_data = torch.empty(
- 0,
- dtype=(
- getattr(torch, q_data_type)
- if isinstance(q_data_type, str)
- else q_data_type
- ),
- )
- self.empty_kv_cache = torch.empty(
- 0,
- dtype=(
- getattr(torch, data_type) if isinstance(data_type, str) else data_type
- ),
- )
- self.last_page_len = torch.ones(32768, dtype=torch.int32)
+ if self.use_tensor_cores:
+ self._qo_indptr_buf = qo_indptr_host.to(
+ self.device, non_blocking=non_blocking
+ )
+
+ # Create empty tensors for dtype info if needed
+ empty_q_data = torch.empty(
+ 0,
+ dtype=(
+ getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+ ),
+ device=self.device,
+ )
+
+ empty_kv_cache = torch.empty(
+ 0,
+ dtype=(
+ getattr(torch, kv_data_type)
+ if isinstance(kv_data_type, str)
+ else kv_data_type
+ ),
+ device=self.device,
+ )
 
  indptr_host = (
  global_override_indptr_cpu
@@ -1215,48 +1241,57 @@ def fast_decode_plan(
  else indptr.cpu()
  )
 
- if self.use_tensor_cores:
- kv_lens_arr_host = get_seq_lens(
- indptr_host, self.last_page_len[:batch_size], page_size
- )
-
- self._plan_info = self._cached_module.plan(
- self._float_workspace_buffer,
- self._int_workspace_buffer,
- self._pin_memory_int_workspace_buffer,
- qo_indptr_host,
- indptr_host,
- kv_lens_arr_host,
- batch_size, # total_num_rows
- batch_size,
- num_qo_heads,
- num_kv_heads,
- page_size,
- self.is_cuda_graph_enabled,
- head_dim,
- head_dim,
- False, # causal
- torch.cuda.current_stream().cuda_stream,
- )
- else:
- self._plan_info = self._cached_module.plan(
- self._float_workspace_buffer,
- self._int_workspace_buffer,
- self._pin_memory_int_workspace_buffer,
- indptr_host,
- batch_size,
- num_qo_heads,
- num_kv_heads,
- page_size,
- self.is_cuda_graph_enabled,
- window_left,
- logits_soft_cap,
- head_dim,
- head_dim,
- self.empty_q_data,
- self.empty_kv_cache,
- torch.cuda.current_stream().cuda_stream,
- )
+ with torch.cuda.device(self.device):
+
+ if self.use_tensor_cores:
+ # ALSO convert last_page_len to CPU
+ last_page_len_host = last_page_len.cpu()
+
+ kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+ try:
+ # Make sure we pass exactly 15 arguments for tensor core version
+ self._plan_info = self._cached_module.plan(
+ self._float_workspace_buffer,
+ self._int_workspace_buffer,
+ self._pin_memory_int_workspace_buffer,
+ qo_indptr_host,
+ indptr_host,
+ kv_lens_arr_host,
+ batch_size, # total_num_rows
+ batch_size,
+ num_qo_heads,
+ num_kv_heads,
+ page_size,
+ self.is_cuda_graph_enabled,
+ head_dim,
+ head_dim,
+ False, # causal
+ )
+ except Exception as e:
+ raise RuntimeError(f"Error in standard plan: {e}")
+ else:
+ try:
+ # Make sure we pass exactly 15 arguments for standard version
+ self._plan_info = self._cached_module.plan(
+ self._float_workspace_buffer,
+ self._int_workspace_buffer,
+ self._pin_memory_int_workspace_buffer,
+ indptr_host,
+ batch_size,
+ num_qo_heads,
+ num_kv_heads,
+ page_size,
+ self.is_cuda_graph_enabled,
+ window_left,
+ logits_soft_cap,
+ head_dim,
+ head_dim,
+ empty_q_data,
+ empty_kv_cache,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Error in standard plan: {e}")
 
  self._pos_encoding_mode = pos_encoding_mode
  self._window_left = window_left