sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
+                (max_num_tokens * self.max_context_len,),
                 dtype=torch.int32,
                 device="cuda",
             )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):

         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device="cuda",
             )
@@ -440,7 +443,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:

         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:

         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
@@ -364,7 +367,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:

         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:

         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:

         self.common_template(forward_batch, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):
@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
             if forward_batch_child.batch_size > 0:
                 child.init_forward_metadata(forward_batch=forward_batch_child)

-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
         for item in self.children:
             # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_attn_lse = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits),
+            (max_num_tokens, self.num_head, self.max_kv_splits),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_num_kv_splits = torch.full(
-            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.int32,
                 device=self.device,
             )
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):

         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
         if self.sliding_window_size is not None and self.sliding_window_size > 0:
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
-                    (max_bs * self.sliding_window_size),
+                    (max_num_tokens * self.sliding_window_size),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                 self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

             self.cuda_graph_window_num_kv_splits = torch.full(
-                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
             )

     def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
             )

             custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:

         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
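Note: across the attention backends above (FlashInfer, FlashInfer MLA, FlashMLA, TBO, Triton), init_cuda_graph_state() gains a max_num_tokens argument, and the per-token CUDA-graph buffers (kv indices, custom masks, attention logits) are now sized by max_num_tokens rather than max_bs. A minimal sketch of the updated call shape, with a hypothetical backend class whose names (other than the signature itself) are illustrative and not taken from the diff:

# Illustrative sketch only, not part of the diff.
from typing import Optional

import torch


class MyAttnBackend:
    def __init__(self, max_context_len: int):
        self.max_context_len = max_context_len

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        # Per-token scratch buffers are sized by max_num_tokens rather than
        # max_bs, so a captured graph can cover speculative-decoding batches
        # where each request contributes several draft tokens per step.
        if kv_indices_buf is None:
            kv_indices_buf = torch.zeros(
                (max_num_tokens * self.max_context_len,),
                dtype=torch.int32,
                device="cuda",
            )
        self.cuda_graph_kv_indices = kv_indices_buf

The multi-step draft backends above forward the same pair (max_bs, max_num_tokens) to every per-step backend, which is why the signature change appears in each file.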
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
     process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
     tp_size: int

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@ class CommunicateContext:
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )

@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             attn_tp_all_gather(
                 list(residual.tensor_split(context.attn_tp_size)), local_residual
             )
-        if context.local_attn_dp_size != 1:
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
@@ -165,7 +165,8 @@ def disable_dp_size():


 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_local_attention_dp_rank()
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -238,6 +239,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -288,6 +293,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
@@ -301,4 +310,4 @@ def attn_tp_reduce_scatter(


 def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
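Note: both _dp_gather and dp_scatter now clamp local_num_tokens to the number of rows actually present in local_tokens when the forward mode is draft-extend, before handing the count to memcpy_triton. A self-contained illustration of the clamp with made-up sizes (not taken from the diff):

# Standalone illustration of the clamp; shapes are fabricated.
import torch

local_tokens = torch.randn(6, 8)        # only 6 rows are actually materialized
local_num_tokens = torch.tensor(10)     # scheduler-side count can be larger in draft-extend

# new_full((), n) builds a 0-dim tensor with the same device/dtype as local_num_tokens.
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

assert int(local_num_tokens) == 6       # the copy length can never exceed the buffer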
@@ -20,11 +20,21 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import (
@@ -121,6 +131,23 @@ class RMSNorm(CustomOp):
         else:
             return x, residual

+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+

 class GemmaRMSNorm(CustomOp):
     def __init__(
@@ -187,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
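Note: the new RMSNorm.forward_cpu dispatches to the sgl-kernel AMX ops (fused_add_rmsnorm_cpu / rmsnorm_cpu) when cpu_has_amx_support() is true and otherwise falls back to forward_native. For orientation only, a simplified sketch of the computation both paths are expected to agree on; the real forward_native also handles dtype casting, so treat this as a reference, not the library implementation:

# Reference-only RMSNorm sketch with an optional fused residual add.
import torch


def rmsnorm_reference(x, weight, eps, residual=None):
    if residual is not None:
        x = x + residual     # fused add: the updated residual is returned alongside the output
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)


x = torch.randn(4, 16)
w = torch.ones(16)
y = rmsnorm_reference(x, w, eps=1e-6)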
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
+    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
-    get_local_attention_dp_rank,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ class LogitsMetadata:
             return

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_local_attention_dp_rank()
+        dp_rank = get_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
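Note: LogitsMetadata now locates its slice of the gathered logprob tokens by global attention DP rank (get_attention_dp_rank) instead of the local one. A toy illustration of the start/length arithmetic from the cumulative token counts, with fabricated sizes:

# Toy example; the per-rank token counts are made up.
import torch

global_num_tokens = torch.tensor([5, 3, 7, 2])   # tokens contributed by each DP rank
cumtokens = torch.cumsum(global_num_tokens, dim=0)

dp_rank = 2
start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
length = int(global_num_tokens[dp_rank])
assert (start, length) == (8, 7)                 # rank 2 owns rows [8, 15) of the gathered tensor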
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
     end_expert_id,
     topk,
     hidden_size,
+    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty

-    src_idx = tl.program_id(0)
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             computed = True
-            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - dst_start
             weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
             load_ptr = down_output_ptr + dst_idx * hidden_size
             in_data = tl.load(load_ptr + offset, mask=mask)
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
         BLOCK_SIZE_K=BLOCK_SIZE_K,
     )
     return output.t()[:m]
+
+
+@triton.jit
+def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+    expert_id = tl.program_id(0)
+    start = tl.load(seg_indptr + expert_id)
+    end = tl.load(seg_indptr + expert_id + 1)
+    tl.store(masked_m + expert_id, (end - start))
+
+
+@triton.jit
+def deepgemm_compute_src2dst_triton_kernel(
+    topk_ids,
+    reorder_ids,
+    seg_indptr,
+    src2dst,
+    m_max,
+    num_toks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_offset = dst_id - expert_dst_start
+    dst_id = expert_id * m_max + expert_dst_offset
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def fill_gateup_input_triton_kernel(
+    input_ptr,
+    scale_ptr,
+    gateup_input_ptr,
+    gateup_input_scale_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    m_max,
+    hidden_size,
+    scale_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    src_ptr = input_ptr + src_idx * hidden_size
+    scale_src_ptr = scale_ptr + src_idx * scale_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - start_expert_id * m_max
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < scale_size
+                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+def moe_ep_deepgemm_preprocess(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    hidden_states: torch.Tensor,
+    top_k: int,
+    start_expert_id,
+    end_expert_id,
+    block_shape,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+    m_max = (hidden_states.size(0) + 255) // 256 * 256
+    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    gateup_input = torch.empty(
+        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        device=hidden_states.device,
+        dtype=output_dtype,
+    )
+
+    deepgemm_compute_src2dst_triton_kernel[grid](
+        topk_ids,
+        reorder_ids,
+        seg_indptr,
+        src2dst,
+        m_max,
+        topk_ids.numel(),
+        BLOCK_SIZE=256,
+    )
+
+    if block_shape is None:
+        block_shape = [128, 128]
+    assert len(block_shape) == 2
+    block_n, block_k = block_shape[0], block_shape[1]
+    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+    gateup_input_scale = torch.empty(
+        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+        device=hidden_states.device,
+        dtype=scale.dtype,
+    )
+
+    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        scale,
+        gateup_input,
+        gateup_input_scale,
+        src2dst,
+        topk_ids,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        m_max,
+        hidden_states.size(1),
+        scale.size(1),
+        BLOCK_SIZE=1024,
+    )
+
+    return (
+        m_max,
+        masked_m[start_expert_id : (end_expert_id + 1)],
+        expected_m,
+        src2dst,
+        gateup_input,
+        gateup_input_scale,
+    )
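Note: moe_ep_deepgemm_preprocess sorts the flattened topk_ids, derives per-expert token counts (masked_m), pads each local expert's slot count to m_max (a multiple of 256, as required by DeepGEMM's masked grouped GEMM), quantizes the activations to FP8 per token group, and scatters them into a (num_local_experts, m_max, hidden) buffer with fill_gateup_input_triton_kernel. A small worked example of the padding and indexing arithmetic, with made-up sizes:

# Worked example of the padding/indexing math above; all sizes are fabricated.
num_tokens, top_k, num_experts = 700, 8, 64

# m_max: per-expert row capacity, rounded up to a multiple of 256.
m_max = (num_tokens + 255) // 256 * 256
assert m_max == 768

# expected_m: average number of routed tokens per expert (ceiling division over topk_ids.numel()).
expected_m = (num_tokens * top_k + num_experts - 1) // num_experts
assert expected_m == 88

# Flat destination row for a token landing at offset `off` within expert `e`,
# matching dst_id = expert_id * m_max + expert_dst_offset in the src2dst kernel.
e, off = 3, 17
dst_idx = e * m_max + off
assert dst_idx == 2321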