sglang 0.5.0rc1__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. sglang/bench_one_batch.py +0 -1
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/decode.py +0 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/entrypoints/http_server.py +64 -0
  6. sglang/srt/entrypoints/openai/protocol.py +2 -0
  7. sglang/srt/entrypoints/openai/serving_chat.py +1 -0
  8. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  9. sglang/srt/layers/attention/flashinfer_backend.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  11. sglang/srt/layers/attention/triton_backend.py +24 -27
  12. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  13. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
  14. sglang/srt/layers/communicator.py +7 -7
  15. sglang/srt/layers/dp_attention.py +118 -27
  16. sglang/srt/layers/logits_processor.py +12 -18
  17. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/multimodal.py +156 -40
  29. sglang/srt/layers/quantization/__init__.py +5 -32
  30. sglang/srt/layers/quantization/awq.py +15 -16
  31. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  32. sglang/srt/layers/quantization/gptq.py +12 -17
  33. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  34. sglang/srt/layers/quantization/modelopt_quant.py +52 -30
  35. sglang/srt/layers/quantization/mxfp4.py +16 -2
  36. sglang/srt/layers/quantization/utils.py +52 -2
  37. sglang/srt/layers/sampler.py +5 -2
  38. sglang/srt/lora/layers.py +6 -2
  39. sglang/srt/managers/cache_controller.py +4 -1
  40. sglang/srt/managers/io_struct.py +14 -0
  41. sglang/srt/managers/schedule_batch.py +18 -39
  42. sglang/srt/managers/scheduler.py +3 -4
  43. sglang/srt/managers/tokenizer_manager.py +28 -18
  44. sglang/srt/mem_cache/allocator.py +8 -157
  45. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  46. sglang/srt/mem_cache/chunk_cache.py +1 -1
  47. sglang/srt/model_executor/cuda_graph_runner.py +8 -21
  48. sglang/srt/model_executor/forward_batch_info.py +8 -10
  49. sglang/srt/model_executor/model_runner.py +57 -53
  50. sglang/srt/models/deepseek_nextn.py +2 -1
  51. sglang/srt/models/deepseek_v2.py +5 -3
  52. sglang/srt/models/glm4_moe.py +2 -2
  53. sglang/srt/models/glm4_moe_nextn.py +2 -1
  54. sglang/srt/models/gpt_oss.py +7 -2
  55. sglang/srt/models/llama.py +10 -2
  56. sglang/srt/models/llama4.py +18 -5
  57. sglang/srt/models/qwen2.py +2 -2
  58. sglang/srt/models/qwen2_moe.py +20 -5
  59. sglang/srt/models/qwen3_classification.py +78 -0
  60. sglang/srt/models/qwen3_moe.py +18 -5
  61. sglang/srt/models/step3_vl.py +6 -2
  62. sglang/srt/operations.py +17 -2
  63. sglang/srt/sampling/sampling_batch_info.py +7 -4
  64. sglang/srt/server_args.py +33 -7
  65. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  66. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  67. sglang/srt/two_batch_overlap.py +4 -8
  68. sglang/test/test_marlin_moe.py +1 -1
  69. sglang/test/test_marlin_utils.py +1 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
  72. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
  73. sglang/srt/layers/quantization/scalar_type.py +0 -352
  74. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  75. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  76. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_tp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache

@@ -349,30 +350,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None

         self.custom_mask = torch.ones(
             (

@@ -556,7 +542,7 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(

@@ -572,9 +558,9 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None

         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:

@@ -607,8 +593,8 @@ class CudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,

@@ -637,6 +623,7 @@ class CudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

             kwargs = {}
             if (
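
Taken together, the cuda_graph_runner.py hunks above stop pre-allocating a (max_num_token * dp_size, hidden_size) gathered_buffer at capture time and instead record only a buffer length, publishing it through the new set_dp_buffer_len helper in sglang.srt.layers.dp_attention. The sketch below only illustrates the shape of that pattern; the module-level variables and the get_dp_gathered_buffer consumer are hypothetical stand-ins, not the actual sglang implementation.

    # Minimal sketch, assuming dp_attention keeps the buffer *length* as module
    # state and consumers allocate (or reuse) the gather buffer on demand.
    from typing import Optional

    import torch

    _global_dp_buffer_len: Optional[int] = None
    _local_dp_buffer_len: Optional[int] = None


    def set_dp_buffer_len(global_len: Optional[int], local_len: Optional[int]) -> None:
        # Called once per forward pass / graph capture instead of threading a
        # pre-allocated tensor through ForwardBatch.
        global _global_dp_buffer_len, _local_dp_buffer_len
        _global_dp_buffer_len = global_len
        _local_dp_buffer_len = local_len


    def get_dp_gathered_buffer(hidden_size: int, dtype: torch.dtype, device: str) -> torch.Tensor:
        # Hypothetical consumer: size the gather buffer from the published length.
        assert _global_dp_buffer_len is not None
        return torch.empty(_global_dp_buffer_len, hidden_size, dtype=dtype, device=device)
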
sglang/srt/model_executor/forward_batch_info.py

@@ -40,9 +40,10 @@ import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (

@@ -274,13 +275,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None

@@ -628,7 +629,7 @@ class ForwardBatch:
                     (global_num_tokens[i] - 1) // attn_tp_size + 1
                 ) * attn_tp_size

-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():
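
The rounding above pads each DP rank's token count up to a multiple of the attention-TP size before the padding mode is chosen. A small worked illustration (the max-len sizing is inferred from the CudaGraphRunner hunk earlier, not shown in this hunk):

    attn_tp_size = 4
    global_num_tokens = [13, 7, 16, 1]

    padded = [((n - 1) // attn_tp_size + 1) * attn_tp_size for n in global_num_tokens]
    print(padded)  # [16, 8, 16, 4]

    # Buffer sizes the two padding modes would imply:
    print(max(padded) * len(padded))  # 64 -- max-len mode, every rank presumably padded to the max
    print(sum(padded))                # 44 -- sum mode, matching `buffer_len = sum(global_num_tokens)` in the next hunk
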
@@ -642,17 +643,14 @@ class ForwardBatch:
         else:
             buffer_len = sum(global_num_tokens)

-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]

+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens)
+
         bs = self.batch_size

         if self.forward_mode.is_decode():
sglang/srt/model_executor/model_runner.py

@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
-    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,

@@ -176,10 +176,6 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
-
-        # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.moe_ep_rank = moe_ep_rank

@@ -205,15 +201,17 @@ class ModelRunner:
         self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
-
         self.forward_pass_id = 0

-        # Model-specific adjustment
-        self.model_specific_adjustment()
-
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()

+        # Model-specific adjustment
+        self.model_specific_adjustment()
+
         # Global vars
         global_server_args_dict.update(
             {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}

@@ -221,8 +219,6 @@ class ModelRunner:
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-            }
-            | {
                 "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
                 "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }

@@ -242,13 +238,15 @@ class ModelRunner:
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

-        # If it is a draft model, tp_group can be different
+        # Initialize the model runner
         self.initialize(min_per_gpu_memory)

-        # temporary cached values
+        # Temporary cached values
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+
+        # For weight updates
         self._model_update_group = {}

     def initialize(self, min_per_gpu_memory: float):

@@ -277,6 +275,7 @@ class ModelRunner:
             )
         )

+        # Expert parallelism
         self.eplb_manager = (
             EPLBManager(self)
             if self.server_args.enable_eplb and (not self.is_draft_worker)

@@ -604,12 +603,8 @@ class ModelRunner:
                 duplicate_tp_group=self.server_args.enable_pdmux,
             )
             initialize_dp_attention(
-                enable_dp_attention=self.server_args.enable_dp_attention,
-                tp_rank=self.tp_rank,
-                tp_size=self.tp_size,
-                dp_size=self.server_args.dp_size,
-                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-                pp_size=self.server_args.pp_size,
+                server_args=self.server_args,
+                model_config=self.model_config,
             )

         min_per_gpu_memory = get_available_gpu_memory(

@@ -1160,6 +1155,7 @@ class ModelRunner:
         max_num_reqs: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":

@@ -1178,6 +1174,8 @@ class ModelRunner:
             )

         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

         if max_num_reqs is None:
             max_num_reqs = min(

@@ -1190,9 +1188,6 @@ class ModelRunner:
                 4096,
             )

-        if SGLANG_CI_SMALL_KV_SIZE:
-            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
-
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size

@@ -1239,6 +1234,7 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

+        # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

@@ -1264,6 +1260,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker

+        # Initialize token_to_kv_pool
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1349,44 +1346,52 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )

+        # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
+        max_num_extend_tokens = (
+            self.server_args.chunked_prefill_size
+            if self.server_args.chunked_prefill_size > 0
+            else self.server_args.max_prefill_tokens
+        )
         if self.token_to_kv_pool_allocator is None:
-            if self.page_size == 1:
-                if self.is_hybrid:
-                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
-                        self.full_max_total_num_tokens,
-                        self.swa_max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
-                else:
-                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
+            if self.server_args.attention_backend == "ascend":
+                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                    need_sort=need_sort,
+                )
             else:
-                if not _is_npu:
-                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        page_size=self.page_size,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=need_sort,
-                    )
+                if self.page_size == 1:
+                    if self.is_hybrid:
+                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                            self.full_max_total_num_tokens,
+                            self.swa_max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
+                    else:
+                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                            self.max_total_num_tokens,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device,
+                            kvcache=self.token_to_kv_pool,
+                            need_sort=need_sort,
+                        )
                 else:
-                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                    assert not self.is_hybrid
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
                         need_sort=need_sort,
+                        max_num_extend_tokens=max_num_extend_tokens,
                     )
         else:
             assert self.is_draft_worker
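
The allocator setup is restructured here: the Ascend path is now selected by attention_backend == "ascend" (with AscendPagedTokenToKVPoolAllocator moved to the new allocator_ascend module), and PagedTokenToKVPoolAllocator gains a max_num_extend_tokens bound derived from chunked_prefill_size or max_prefill_tokens. The helper below merely restates the new branch structure for readability; it is illustrative and not part of sglang.

    # Illustrative restatement of the selection logic above; the helper itself is
    # hypothetical, but the branch conditions and class names mirror the diff.
    def pick_allocator_kind(attention_backend: str, page_size: int, is_hybrid: bool) -> str:
        if attention_backend == "ascend":
            return "AscendPagedTokenToKVPoolAllocator"  # now imported from allocator_ascend
        if page_size == 1:
            return "SWATokenToKVPoolAllocator" if is_hybrid else "TokenToKVPoolAllocator"
        # Paged path; hybrid (SWA) models are asserted not to reach it.
        return "PagedTokenToKVPoolAllocator"

    assert pick_allocator_kind("ascend", 128, False) == "AscendPagedTokenToKVPoolAllocator"
    assert pick_allocator_kind("fa3", 1, True) == "SWATokenToKVPoolAllocator"
    assert pick_allocator_kind("fa3", 64, False) == "PagedTokenToKVPoolAllocator"
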
@@ -1554,15 +1559,13 @@ class ModelRunner:
                 )

             return TRTLLMHAAttnBackend(self)
-
         elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )

-            logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
-        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+        elif backend_str == "dual_chunk_flash_attn":
             from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                 DualChunkFlashAttentionBackend,
             )

@@ -1606,6 +1609,7 @@ class ModelRunner:
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
sglang/srt/models/deepseek_nextn.py

@@ -22,6 +22,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )

sglang/srt/models/deepseek_v2.py

@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -1797,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
         self.is_nextn = is_nextn

@@ -1917,7 +1917,9 @@ class DeepseekV2DecoderLayer(nn.Module):

         should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not (
+                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            )
             and not self.is_nextn
         )

@@ -2047,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
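
Starting with deepseek_nextn.py and deepseek_v2.py above, and continuing in glm4_moe.py, glm4_moe_nextn.py, gpt_oss.py, llama4.py, and qwen2.py below, the models replace global_server_args_dict["enable_dp_attention"] with the new is_dp_attention_enabled() helper from sglang.srt.layers.dp_attention, which pairs with the simplified initialize_dp_attention(server_args=..., model_config=...) call in model_runner.py. A minimal sketch of the kind of module-level flag such a helper could wrap (illustrative only; the real initialize_dp_attention also sets up the DP/TP attention groups):

    # Minimal sketch, assuming the enable flag becomes module state that model
    # code can query without importing global_server_args_dict.
    _ENABLE_DP_ATTENTION: bool = False


    def initialize_dp_attention(server_args, model_config=None) -> None:
        # Only the flag relevant to is_dp_attention_enabled() is sketched here.
        global _ENABLE_DP_ATTENTION
        _ENABLE_DP_ATTENTION = bool(getattr(server_args, "enable_dp_attention", False))


    def is_dp_attention_enabled() -> bool:
        return _ENABLE_DP_ATTENTION
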
sglang/srt/models/glm4_moe.py

@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,

@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
sglang/srt/models/glm4_moe_nextn.py

@@ -22,6 +22,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )

sglang/srt/models/gpt_oss.py

@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -293,8 +294,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )

+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )

         self.o_proj = RowParallelLinear(

@@ -561,7 +566,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/llama.py

@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x

sglang/srt/models/llama4.py

@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,

@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )

         out_aD = routed_out + shared_out

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)

         return out_aD

@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_moe_layer(self, layer_id: int) -> bool:

@@ -441,8 +447,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states, forward_batch)
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
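
In llama.py above and in these llama4.py hunks, the decoder layer passes use_reduce_scatter down to the MLP/MoE block so the tensor-parallel all-reduce can be skipped when the layer communicator will reduce-scatter instead (the new allow_reduce_scatter / should_use_reduce_scatter path, with matching changes in layers/communicator.py per the file list). The toy, single-process snippet below only illustrates the algebraic identity this relies on, namely that reduce-scattering per-rank partial sums yields the local shard of the all-reduced result; it is not sglang code.

    import torch

    tp_size, tokens, hidden = 4, 8, 16
    partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]  # one partial MLP output per TP rank

    full = torch.stack(partials).sum(dim=0)   # what all-reduce would give every rank
    shards = full.chunk(tp_size, dim=0)       # per-rank shards after a scatter

    rank = 2
    reduce_scatter_out = torch.stack(
        [p.chunk(tp_size, dim=0)[rank] for p in partials]
    ).sum(dim=0)
    assert torch.allclose(shards[rank], reduce_scatter_out, atol=1e-5)
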
@@ -466,7 +479,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
sglang/srt/models/qwen2.py

@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,

@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
         else: