sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -657,12 +657,16 @@ class FlashAttentionBackend(AttentionBackend):
        )
        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
        q = q.to(self.kv_cache_dtype)
+        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        causal = not layer.is_cross_attention

        # Check if we should use local attention
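Aside on the hunk above (illustrative example, not part of the package diff): descale factors are only built when an fp8 KV cache dtype is configured, the layer actually carries a k_scale, and head_dim is at most 256. A minimal sketch of what the expansion produces, using made-up standalone tensors rather than the backend's own objects:

import torch

batch_size, tp_k_head_num = 4, 8
k_scale = torch.tensor([0.02])              # per-layer scale recorded when the KV cache was quantized
descale_shape = (batch_size, tp_k_head_num)
k_descale = k_scale.expand(descale_shape)   # one descale entry per (request, local KV head), no data copy
print(k_descale.shape)                      # torch.Size([4, 8])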
@@ -776,8 +780,8 @@ class FlashAttentionBackend(AttentionBackend):

                    output, lse, *rest = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                        max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +794,8 @@ class FlashAttentionBackend(AttentionBackend):
                    # MHA for extend part of sequence without attending prefix kv cache
                    output, lse, *rest = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=metadata.cu_seqlens_q,
                        max_seqlen_q=metadata.max_seq_len_q,
@@ -803,7 +807,9 @@ class FlashAttentionBackend(AttentionBackend):
                return output, lse
        else:
            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
@@ -933,14 +939,16 @@ class FlashAttentionBackend(AttentionBackend):

        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto":
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
            if layer.k_scale is not None:
                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                k_descale = layer.k_scale.expand(descale_shape)
                v_descale = layer.v_scale.expand(descale_shape)
        q = q.to(self.kv_cache_dtype)
-
+        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        if not self.use_mla:
            # Do multi-head attention

@@ -1048,7 +1056,9 @@ class FlashAttentionBackend(AttentionBackend):
            o = result
        else:
            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
@@ -1120,7 +1130,7 @@ class FlashAttentionBackend(AttentionBackend):

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
@@ -1704,14 +1714,15 @@ class FlashAttentionBackend(AttentionBackend):

            # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
            metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
            # metadata_expand.max_seq_len_q = 1, already set in capture
            # metadata_expand.cu_seqlens_q already set in capture
-
            offsets = torch.arange(
                self.speculative_num_draft_tokens, device=device
            ).unsqueeze(
                0
            )  # shape: (1, self.speculative_num_draft_tokens)
+
            cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
            cum_len = torch.nn.functional.pad(
                torch.cumsum(
@@ -1728,17 +1739,20 @@ class FlashAttentionBackend(AttentionBackend):
            ).view(1, -1)
            # avoid extracting padded seq indices which will be out of boundary
            mask_extraction_indices[
-                :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
            ].fill_(0)
-
            mask = spec_info.custom_mask[mask_extraction_indices].view(
                -1, self.speculative_num_draft_tokens
            )  # (bsz * draft_num, draft_num)
+
            col_indices = offsets.expand(
                mask.shape[0], self.speculative_num_draft_tokens
            )
            keys = torch.where(
-                mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
            )
            _, sort_order = torch.sort(keys, dim=1)

@@ -1747,6 +1761,7 @@ class FlashAttentionBackend(AttentionBackend):
                .gather(1, cols)
                .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
            )  # (bsz, draft_num)
+
            metadata_expand.page_table.copy_(
                non_masked_page_table.gather(1, sort_order)
            )
@@ -1758,6 +1773,7 @@ class FlashAttentionBackend(AttentionBackend):
                    dtype=torch.int32,
                )
            )
+
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1783,11 @@ class FlashAttentionBackend(AttentionBackend):
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )
            accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
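Aside on the last hunk (illustrative example, not part of the package diff): the new branch appears to guard the draft-extend path against an empty accept_length_cpu list, since Python's max() raises ValueError on an empty sequence; the fallback of 1 keeps the replay metadata valid. A tiny illustration with a hypothetical list:

accept_length_cpu = []            # e.g. no accepted draft lengths reported for this replay
try:
    max_seq_len_q = max(accept_length_cpu) + 1
except ValueError:
    max_seq_len_q = 1             # matches the new else-branch
print(max_seq_len_q)              # 1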
@@ -1807,7 +1827,7 @@ class FlashAttentionBackend(AttentionBackend):

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1

    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -1999,9 +2019,9 @@ class FlashAttentionMultiStepBackend:
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
+                (max_num_tokens * self.max_context_len,),
                dtype=torch.int32,
                device="cuda",
            )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device="cuda",
            )
@@ -440,7 +443,7 @@ class FlashInferAttnBackend(AttentionBackend):
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

    def forward_extend(
        self,
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
@@ -364,7 +367,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

    def forward_extend(
        self,
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
    def init_cuda_graph_state(
        self,
        max_bs: int,
+        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:

        self.common_template(forward_batch, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
            if forward_batch_child.batch_size > 0:
                child.init_forward_metadata(forward_batch=forward_batch_child)

-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
        for item in self.children:
            # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
+
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits),
+            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
-            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
-                    (max_bs * self.sliding_window_size),
+                    (max_num_tokens * self.sliding_window_size),
                    dtype=torch.int32,
                    device=self.device,
                )
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
-                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
            )

            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
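Aside on the attention-backend hunks above (illustrative example, not part of the package diff): the recurring change across the FlashAttention, FlashInfer, FlashMLA, TBO, and Triton backends is that init_cuda_graph_state now takes max_num_tokens in addition to max_bs, and buffers indexed per token are sized by max_num_tokens while per-request buffers keep using max_bs. A hedged sketch of that pattern with a toy backend (the class and attribute names here are illustrative, not sglang APIs):

import torch

class ToyAttnBackend:
    def __init__(self, max_context_len: int, device: str = "cpu"):
        # the real backends use the model's accelerator device; "cpu" keeps this sketch runnable anywhere
        self.max_context_len = max_context_len
        self.device = device

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # per-token workspace: with speculative decoding a captured graph can carry several
        # draft tokens per request, so presumably max_num_tokens >= max_bs
        self.cuda_graph_kv_indices = torch.zeros(
            (max_num_tokens * self.max_context_len,),
            dtype=torch.int32,
            device=self.device,
        )
        # per-request metadata, still sized by max_bs; filled with 1, echoing the
        # get_cuda_graph_seq_len_fill_value() change from 0 to 1 in this release
        self.cuda_graph_seq_lens = torch.full(
            (max_bs,), 1, dtype=torch.int32, device=self.device
        )

backend = ToyAttnBackend(max_context_len=4096)
backend.init_cuda_graph_state(max_bs=8, max_num_tokens=32)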
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
    attn_tp_reduce_scatter,
    dp_gather_partial,
    dp_scatter,
+    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
-    get_local_attention_dp_size,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@ class CommunicateContext:
    def init_new(cls):
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
        tp_size = get_tensor_model_parallel_world_size()
        process_group_sizes = {
            ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
            process_group_sizes=process_group_sizes,
            attn_tp_rank=attn_tp_rank,
            attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
            tp_size=tp_size,
        )

@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
            attn_tp_all_gather(
                list(residual.tensor_split(context.attn_tp_size)), local_residual
            )
-        if context.local_attn_dp_size != 1:
+        if context.attn_dp_size != 1:
            if context.attn_tp_rank == 0:
                hidden_states += residual
            hidden_states, local_hidden_states = (
@@ -165,7 +165,8 @@ def disable_dp_size():


def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_local_attention_dp_rank()
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()

    if forward_batch.dp_local_start_pos is None:
        cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -238,6 +239,10 @@ def _dp_gather(
    assert (
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
    memcpy_triton(
        global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
    )
@@ -288,6 +293,10 @@ def dp_scatter(
    assert (
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
    memcpy_triton(
        local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )
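Aside on the two hunks above (illustrative example, not part of the package diff): in draft-extend mode, both _dp_gather and dp_scatter now clamp local_num_tokens so the Triton memcpy never copies more rows than local_tokens actually holds. new_full((), n) builds a 0-dim tensor with the same dtype and device as the counter, and torch.minimum broadcasts it elementwise. With made-up standalone tensors:

import torch

local_num_tokens = torch.tensor([7])                                  # announced token count for this DP rank
local_tokens = torch.zeros(5, 16)                                     # but only 5 rows are materialized
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])   # 0-dim, same dtype/device as the counter
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
print(local_num_tokens)                                               # tensor([5]) -- copy length never exceeds the buffer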
@@ -301,4 +310,4 @@


def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
@@ -20,11 +20,21 @@ import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

_is_cuda = is_cuda()
_is_hip = is_hip()
+_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import (
@@ -42,6 +52,9 @@ elif _is_hip:

logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+


class RMSNorm(CustomOp):
    def __init__(
@@ -66,6 +79,18 @@ class RMSNorm(CustomOp):
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
    def forward_aiter(
        self,
        x: torch.Tensor,
@@ -121,6 +146,23 @@
        else:
            return x, residual

+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+

class GemmaRMSNorm(CustomOp):
    def __init__(
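Aside on the RMSNorm hunks above (illustrative example, not part of the package diff): the new forward_npu and forward_cpu paths dispatch to torch_npu and sgl-kernel AMX kernels, and their expected semantics are those of the existing native RMSNorm path. A hedged reference sketch of that computation, with dtype up-casting details omitted:

import torch

def rms_norm_reference(x, weight, eps, residual=None):
    # fused-add variant: the residual stream is updated first and returned alongside the output
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)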
@@ -187,7 +229,7 @@ class Gemma3RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    logger.info(
        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
    )
@@ -30,7 +30,12 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

logger = logging.getLogger(__name__)

@@ -52,6 +57,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
    "IPEXAWQLinearMethod",
]

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
    def apply(
        self,
        layer: torch.nn.Module,
@@ -172,6 +184,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
        return F.linear(x, layer.weight, bias)


@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather,
    dp_gather_replicate,
    dp_scatter,
+    get_attention_dp_rank,
    get_attention_dp_size,
    get_attention_tp_size,
-    get_local_attention_dp_rank,
    get_local_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ class LogitsMetadata:
            return

        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_local_attention_dp_rank()
+        dp_rank = get_attention_dp_rank()
        if dp_rank == 0:
            dp_local_start_pos = torch.zeros_like(
                self.global_num_tokens_for_logprob_gpu[0]
@@ -442,11 +442,20 @@ class LogitsProcessor(nn.Module):
            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

        if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
        else:
            # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

        if self.logit_scale is not None: