sglang-0.5.0rc1-py3-none-any.whl → sglang-0.5.0rc2-py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (76)
  1. sglang/bench_one_batch.py +0 -1
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/decode.py +0 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/entrypoints/http_server.py +64 -0
  6. sglang/srt/entrypoints/openai/protocol.py +2 -0
  7. sglang/srt/entrypoints/openai/serving_chat.py +1 -0
  8. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  9. sglang/srt/layers/attention/flashinfer_backend.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  11. sglang/srt/layers/attention/triton_backend.py +24 -27
  12. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  13. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
  14. sglang/srt/layers/communicator.py +7 -7
  15. sglang/srt/layers/dp_attention.py +118 -27
  16. sglang/srt/layers/logits_processor.py +12 -18
  17. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/multimodal.py +156 -40
  29. sglang/srt/layers/quantization/__init__.py +5 -32
  30. sglang/srt/layers/quantization/awq.py +15 -16
  31. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  32. sglang/srt/layers/quantization/gptq.py +12 -17
  33. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  34. sglang/srt/layers/quantization/modelopt_quant.py +52 -30
  35. sglang/srt/layers/quantization/mxfp4.py +16 -2
  36. sglang/srt/layers/quantization/utils.py +52 -2
  37. sglang/srt/layers/sampler.py +5 -2
  38. sglang/srt/lora/layers.py +6 -2
  39. sglang/srt/managers/cache_controller.py +4 -1
  40. sglang/srt/managers/io_struct.py +14 -0
  41. sglang/srt/managers/schedule_batch.py +18 -39
  42. sglang/srt/managers/scheduler.py +3 -4
  43. sglang/srt/managers/tokenizer_manager.py +28 -18
  44. sglang/srt/mem_cache/allocator.py +8 -157
  45. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  46. sglang/srt/mem_cache/chunk_cache.py +1 -1
  47. sglang/srt/model_executor/cuda_graph_runner.py +8 -21
  48. sglang/srt/model_executor/forward_batch_info.py +8 -10
  49. sglang/srt/model_executor/model_runner.py +57 -53
  50. sglang/srt/models/deepseek_nextn.py +2 -1
  51. sglang/srt/models/deepseek_v2.py +5 -3
  52. sglang/srt/models/glm4_moe.py +2 -2
  53. sglang/srt/models/glm4_moe_nextn.py +2 -1
  54. sglang/srt/models/gpt_oss.py +7 -2
  55. sglang/srt/models/llama.py +10 -2
  56. sglang/srt/models/llama4.py +18 -5
  57. sglang/srt/models/qwen2.py +2 -2
  58. sglang/srt/models/qwen2_moe.py +20 -5
  59. sglang/srt/models/qwen3_classification.py +78 -0
  60. sglang/srt/models/qwen3_moe.py +18 -5
  61. sglang/srt/models/step3_vl.py +6 -2
  62. sglang/srt/operations.py +17 -2
  63. sglang/srt/sampling/sampling_batch_info.py +7 -4
  64. sglang/srt/server_args.py +33 -7
  65. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  66. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  67. sglang/srt/two_batch_overlap.py +4 -8
  68. sglang/test/test_marlin_moe.py +1 -1
  69. sglang/test/test_marlin_utils.py +1 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
  72. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
  73. sglang/srt/layers/quantization/scalar_type.py +0 -352
  74. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  75. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  76. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
  model_config=model_runner.model_config,
  enable_overlap=False,
  spec_algorithm=SpeculativeAlgorithm.NONE,
- enable_custom_logit_processor=False,
  )
  batch.prepare_for_extend()
  _maybe_prepare_mlp_sync_batch(batch, model_runner)
sglang/srt/configs/model_config.py CHANGED
@@ -642,6 +642,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
  or "InternLM2ForRewardModel" in model_architectures
  or "Qwen2ForRewardModel" in model_architectures
  or "Qwen2ForSequenceClassification" in model_architectures
+ or "Qwen3ForSequenceClassification" in model_architectures
  or "CLIPModel" in model_architectures
  or "BertModel" in model_architectures
  or "Contriever" in model_architectures
sglang/srt/disaggregation/decode.py CHANGED
@@ -864,7 +864,6 @@ class SchedulerDisaggregationDecodeMixin:
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
- self.server_args.enable_custom_logit_processor,
  )

  # construct fake completed prefill
sglang/srt/entrypoints/engine.py CHANGED
@@ -647,7 +647,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if server_args.attention_backend == "flashinfer":
  assert_pkg_version(
  "flashinfer_python",
- "0.2.11.post1",
+ "0.2.11.post3",
  "Please uninstall the old version and "
  "reinstall the latest version by following the instructions "
  "at https://docs.flashinfer.ai/installation.html.",
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
  assert_pkg_version(
  "sgl-kernel",
- "0.3.4",
+ "0.3.5",
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
  )

sglang/srt/entrypoints/http_server.py CHANGED
@@ -88,6 +88,7 @@ from sglang.srt.managers.io_struct import (
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
  UpdateWeightsFromTensorReqInput,
+ UpdateWeightVersionReqInput,
  VertexGenerateReqInput,
  )
  from sglang.srt.managers.template_manager import TemplateManager
@@ -342,10 +343,19 @@ async def get_model_info():
  "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
  "is_generation": _global_state.tokenizer_manager.is_generation,
  "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
+ "weight_version": _global_state.tokenizer_manager.server_args.weight_version,
  }
  return result


+ @app.get("/get_weight_version")
+ async def get_weight_version():
+ """Get the current weight version."""
+ return {
+ "weight_version": _global_state.tokenizer_manager.server_args.weight_version
+ }
+
+
  @app.get("/get_server_info")
  async def get_server_info():
  # Returns interna states per DP.
@@ -537,6 +547,12 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
  success, message, num_paused_requests = (
  await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
  )
+
+ # Update weight version if provided and weights update was successful
+ if success and obj.weight_version is not None:
+ _update_weight_version_if_provided(obj.weight_version)
+ message += f" Weight version updated to {obj.weight_version}."
+
  content = {
  "success": success,
  "message": message,
@@ -583,6 +599,12 @@ async def update_weights_from_tensor(
  success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
  obj, request
  )
+
+ # Update weight version if provided and weights update was successful
+ if success and obj.weight_version is not None:
+ _update_weight_version_if_provided(obj.weight_version)
+ message += f" Weight version updated to {obj.weight_version}."
+
  content = {"success": success, "message": message}
  return ORJSONResponse(
  content, status_code=200 if success else HTTPStatus.BAD_REQUEST
@@ -599,6 +621,12 @@ async def update_weights_from_distributed(
  obj, request
  )
  )
+
+ # Update weight version if provided and weights update was successful
+ if success and obj.weight_version is not None:
+ _update_weight_version_if_provided(obj.weight_version)
+ message += f" Weight version updated to {obj.weight_version}."
+
  content = {"success": success, "message": message}
  if success:
  return ORJSONResponse(content, status_code=200)
@@ -606,6 +634,36 @@ async def update_weights_from_distributed(
  return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


+ @app.post("/update_weight_version")
+ async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
+ """Update the weight version. This operation requires no active requests."""
+ if obj.abort_all_requests:
+ _global_state.tokenizer_manager.abort_request(abort_all=True)
+
+ # Use a simple approach without the complex lock mechanism for now
+ # since weight_version update is a simple operation that doesn't affect model weights
+ try:
+ # Update the weight version in server args (the single source of truth)
+ _global_state.tokenizer_manager.server_args.weight_version = obj.new_version
+
+ return ORJSONResponse(
+ {
+ "success": True,
+ "message": f"Weight version updated to {obj.new_version}",
+ "new_version": obj.new_version,
+ },
+ status_code=HTTPStatus.OK,
+ )
+ except Exception as e:
+ return ORJSONResponse(
+ {
+ "success": False,
+ "message": f"Failed to update weight version: {str(e)}",
+ },
+ status_code=HTTPStatus.BAD_REQUEST,
+ )
+
+
  @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
  async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
  """Get model parameter by name."""
@@ -966,6 +1024,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
  return ORJSONResponse({"predictions": ret})


+ def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
+ """Update weight version if provided."""
+ if weight_version is not None:
+ _global_state.tokenizer_manager.server_args.weight_version = weight_version
+
+
  def _create_error_response(e):
  return ORJSONResponse(
  {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
sglang/srt/entrypoints/openai/protocol.py CHANGED
@@ -240,6 +240,7 @@ class CompletionResponse(BaseModel):
  model: str
  choices: List[CompletionResponseChoice]
  usage: UsageInfo
+ metadata: Optional[Dict[str, Any]] = None


  class CompletionResponseStreamChoice(BaseModel):
@@ -517,6 +518,7 @@ class ChatCompletionResponse(BaseModel):
  model: str
  choices: List[ChatCompletionResponseChoice]
  usage: UsageInfo
+ metadata: Optional[Dict[str, Any]] = None


  class DeltaMessage(BaseModel):
sglang/srt/entrypoints/openai/serving_chat.py CHANGED
@@ -723,6 +723,7 @@ class OpenAIServingChat(OpenAIServingBase):
  model=request.model,
  choices=choices,
  usage=usage,
+ metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
  )

  def _process_logprobs_tokens(
sglang/srt/entrypoints/openai/serving_completions.py CHANGED
@@ -373,6 +373,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
  created=created,
  choices=choices,
  usage=usage,
+ metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
  )

  def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
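The two hunks above, together with the protocol change, surface the weight version in non-streaming OpenAI-compatible responses. A rough client-side sketch (the address and model name are placeholders):

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # placeholder address
    json={
        "model": "default",  # placeholder served model name
        "messages": [{"role": "user", "content": "ping"}],
    },
).json()

print(resp["usage"])
print(resp.get("metadata"))  # e.g. {"weight_version": ...} on non-streaming responses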
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -122,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
  # Allocate buffers
  global global_workspace_buffer
  if global_workspace_buffer is None:
+ # different from flashinfer zero_init_global_workspace_buffer
  global_workspace_buffer = torch.empty(
  global_config.flashinfer_workspace_size,
  dtype=torch.uint8,
@@ -870,6 +871,8 @@ class FlashInferIndicesUpdaterPrefill:
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
  if use_ragged:
+ # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
+ # and forward_batch.extend_seq_lens_cpu
  paged_kernel_lens = prefix_lens
  paged_kernel_lens_sum = paged_kernel_lens.sum().item()
  else:
sglang/srt/layers/attention/flashinfer_mla_backend.py CHANGED
@@ -81,6 +81,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  # Allocate buffers
  global global_workspace_buffer
  if global_workspace_buffer is None:
+ # different from flashinfer zero_init_global_workspace_buffer
  global_workspace_buffer = torch.empty(
  global_config.flashinfer_workspace_size,
  dtype=torch.uint8,
sglang/srt/layers/attention/triton_backend.py CHANGED
@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
  self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
  self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

+ # Parse args
  self.skip_prefill = skip_prefill
-
  max_bs = model_runner.req_to_token_pool.size
+ self.sliding_window_size = model_runner.sliding_window_size
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
+ self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+ self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+ self.num_head = (
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
+ )
+ self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+ get_attention_tp_size()
+ )
+ self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+ self.max_context_len = model_runner.model_config.context_len
+ self.device = model_runner.device
+ self.device_core_count = get_device_core_count(model_runner.gpu_id)
+ self.static_kv_splits = get_bool_env_var(
+ "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+ )
+ self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

+ # Check arguments
  assert not (
  model_runner.sliding_window_size is not None
  and model_runner.model_config.is_encoder_decoder
  ), "Sliding window and cross attention are not supported together"
- self.sliding_window_size = model_runner.sliding_window_size

+ # Initialize buffers
  # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
  if kv_indptr_buf is None:
  self.kv_indptr = torch.zeros(
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
  # When provided a buffer, create a clone for the second buffer
  self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

- self.req_to_token = model_runner.req_to_token_pool.req_to_token
- self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
-
  if not self.skip_prefill:
  self.qo_indptr = torch.zeros(
  (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
  (max_bs + 1,), dtype=torch.int64, device=model_runner.device
  )

- self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
- self.speculative_num_steps = model_runner.server_args.speculative_num_steps
-
- self.num_head = (
- model_runner.model_config.num_attention_heads // get_attention_tp_size()
- )
- self.num_kv_head = model_runner.model_config.get_num_kv_heads(
- get_attention_tp_size()
- )
-
- self.static_kv_splits = get_bool_env_var(
- "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
- )
- self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
- self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-
+ # Initialize forward metadata
  self.forward_metadata: ForwardMetadata = None

- self.max_context_len = model_runner.model_config.context_len
-
- self.device = model_runner.device
- self.device_core_count = get_device_core_count(model_runner.gpu_id)
-
  def get_num_kv_splits(
  self,
  num_kv_splits: torch.Tensor,
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
  mask_indptr = None
  attn_logits = None
  attn_lse = None
- max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+ max_extend_len = max(forward_batch.extend_seq_lens_cpu)
  num_kv_splits = None

  self.forward_metadata = ForwardMetadata(
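The max_extend_len change above replaces a GPU reduction followed by .item() with a plain Python max over the host-side copy of the sequence lengths, avoiding a device synchronization. A standalone sketch of the difference (the lengths here are made up for illustration):

import torch

extend_seq_lens_cpu = [5, 17, 9]  # host-side copy kept alongside the batch
device = "cuda" if torch.cuda.is_available() else "cpu"
extend_seq_lens = torch.tensor(extend_seq_lens_cpu, device=device)

max_with_sync = torch.max(extend_seq_lens).item()  # triggers a device-to-host sync on GPU
max_without_sync = max(extend_seq_lens_cpu)        # pure host-side computation, no sync
assert max_with_sync == max_without_sync == 17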
sglang/srt/layers/attention/trtllm_mha_backend.py CHANGED
@@ -23,10 +23,12 @@ if TYPE_CHECKING:
  from sglang.srt.speculative.spec_info import SpecInfo

  # Constants
- DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
+ DEFAULT_WORKSPACE_SIZE_MB = (
+ 512 # Memory workspace size in MB, todo(Yingyi): read from config
+ )

  # Reuse this workspace buffer across all TRTLLM MHA wrappers
- global_workspace_buffer = None
+ global_zero_init_workspace_buffer = None


  @dataclass
@@ -73,14 +75,14 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
  # Workspace allocation
  self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
  # Allocate buffers
- global global_workspace_buffer
- if global_workspace_buffer is None:
- global_workspace_buffer = torch.empty(
+ global global_zero_init_workspace_buffer
+ if global_zero_init_workspace_buffer is None:
+ global_zero_init_workspace_buffer = torch.zeros(
  self.workspace_size,
  dtype=torch.uint8,
  device=model_runner.device,
  )
- self.workspace_buffer = global_workspace_buffer
+ self.workspace_buffer = global_zero_init_workspace_buffer

  # CUDA graph state
  self.decode_cuda_graph_metadata = {}
sglang/srt/layers/attention/trtllm_mla_backend.py CHANGED
@@ -39,6 +39,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
  # compute the LCM with other padding constraints.
  TRTLLM_BLOCK_CONSTRAINT = 128

+ global_zero_init_workspace_buffer = None
+

  @dataclass
  class TRTLLMMLADecodeMetadata:
@@ -83,9 +85,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

  # Workspace allocation
  self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
- self.workspace_buffer = torch.empty(
- self.workspace_size, dtype=torch.int8, device=self.device
- )
+ global global_zero_init_workspace_buffer
+ if global_zero_init_workspace_buffer is None:
+ global_zero_init_workspace_buffer = torch.zeros(
+ self.workspace_size,
+ dtype=torch.uint8,
+ device=model_runner.device,
+ )
+ self.workspace_buffer = global_zero_init_workspace_buffer

  # CUDA graph state
  self.decode_cuda_graph_metadata = {}
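Both TRT-LLM backends now share one lazily allocated, zero-initialized workspace buffer instead of allocating a fresh torch.empty tensor per backend instance. A standalone sketch of that pattern (get_shared_workspace and _shared_workspace are illustrative names, not sglang APIs):

import torch

_shared_workspace = None  # module-level cache, analogous to global_zero_init_workspace_buffer

def get_shared_workspace(size_bytes: int, device: str = "cpu") -> torch.Tensor:
    # Allocate the workspace once, zero-initialized, then hand out the same tensor.
    global _shared_workspace
    if _shared_workspace is None:
        _shared_workspace = torch.zeros(size_bytes, dtype=torch.uint8, device=device)
    return _shared_workspace

buf_a = get_shared_workspace(1024)
buf_b = get_shared_workspace(1024)
assert buf_a.data_ptr() == buf_b.data_ptr()  # both callers see the same storage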
sglang/srt/layers/communicator.py CHANGED
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
  get_attention_dp_size,
  get_attention_tp_rank,
  get_attention_tp_size,
+ get_global_dp_buffer,
+ get_local_dp_buffer,
  )
  from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
  context: CommunicateContext,
  ) -> torch.Tensor:
  hidden_states, local_hidden_states = (
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ get_local_dp_buffer(),
  hidden_states,
  )
  attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
  ):
  if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
  residual, local_residual = (
- torch.empty_like(
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
- ),
+ get_local_dp_buffer(),
  residual,
  )
  attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
  residual = hidden_states
  hidden_states = layernorm(hidden_states)
  hidden_states, local_hidden_states = (
- torch.empty_like(forward_batch.gathered_buffer),
+ get_global_dp_buffer(),
  hidden_states,
  )
  dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
  allow_reduce_scatter: bool = False,
  ):
  hidden_states, global_hidden_states = (
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ get_local_dp_buffer(),
  hidden_states,
  )
  if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
  hidden_states += residual
  residual = None
  hidden_states, local_hidden_states = (
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ get_local_dp_buffer(),
  hidden_states,
  )
  attn_tp_all_gather_into_tensor(
sglang/srt/layers/dp_attention.py CHANGED
@@ -4,7 +4,7 @@ import functools
  import logging
  from contextlib import contextmanager
  from enum import IntEnum, auto
- from typing import TYPE_CHECKING, List, Tuple
+ from typing import TYPE_CHECKING, List, Optional, Tuple

  import torch
  import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
  tensor_model_parallel_all_reduce,
  )

+ if TYPE_CHECKING:
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.server_args import ServerArgs
+
  logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch

- _ATTN_TP_GROUP = None
- _ATTN_TP_RANK = None
- _ATTN_TP_SIZE = None
- _ATTN_DP_RANK = None
- _ATTN_DP_SIZE = None
- _LOCAL_ATTN_DP_SIZE = None
- _LOCAL_ATTN_DP_RANK = None
+ _ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+ _ATTN_TP_RANK: Optional[int] = None
+ _ATTN_TP_SIZE: Optional[int] = None
+ _ATTN_DP_RANK: Optional[int] = None
+ _ATTN_DP_SIZE: Optional[int] = None
+ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
+ _LOCAL_ATTN_DP_RANK: Optional[int] = None
+ _ENABLE_DP_ATTENTION_FLAG: bool = False


- class DPPaddingMode(IntEnum):
+ class DpPaddingMode(IntEnum):

  # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
  MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
  SUM_LEN = auto()

  def is_max_len(self):
- return self == DPPaddingMode.MAX_LEN
+ return self == DpPaddingMode.MAX_LEN

  def is_sum_len(self):
- return self == DPPaddingMode.SUM_LEN
+ return self == DpPaddingMode.SUM_LEN

  @classmethod
- def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+ def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
  # we choose the mode that minimizes the communication cost
  max_len = max(global_num_tokens)
  sum_len = sum(global_num_tokens)
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
  return cls.SUM_LEN

  @classmethod
- def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+ def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
  return cls.MAX_LEN


+ class _DpGatheredBufferWrapper:
+
+ _hidden_size: int
+ _dtype: torch.dtype
+ _device: torch.device
+ _global_dp_buffer_len: int
+ _local_dp_buffer_len: int
+
+ @classmethod
+ def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+ cls._hidden_size = hidden_size
+ cls._dtype = dtype
+ cls._device = device
+
+ @classmethod
+ def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+ cls._global_dp_buffer_len = global_dp_buffer_len
+ cls._local_dp_buffer_len = local_dp_buffer_len
+
+ @classmethod
+ def get_global_dp_buffer(cls) -> torch.Tensor:
+ return torch.empty(
+ (cls._global_dp_buffer_len, cls._hidden_size),
+ dtype=cls._dtype,
+ device=cls._device,
+ )
+
+ @classmethod
+ def get_local_dp_buffer(cls) -> torch.Tensor:
+ return torch.empty(
+ (cls._local_dp_buffer_len, cls._hidden_size),
+ dtype=cls._dtype,
+ device=cls._device,
+ )
+
+ @classmethod
+ def get_global_dp_buffer_len(cls) -> int:
+ return cls._global_dp_buffer_len
+
+ @classmethod
+ def get_local_dp_buffer_len(cls) -> int:
+ return cls._local_dp_buffer_len
+
+
+ def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+ _DpGatheredBufferWrapper.set_dp_buffer_len(
+ global_dp_buffer_len, local_dp_buffer_len
+ )
+
+
+ def get_global_dp_buffer() -> torch.Tensor:
+ return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+ def get_local_dp_buffer() -> torch.Tensor:
+ return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+ def get_global_dp_buffer_len() -> int:
+ return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+ def get_local_dp_buffer_len() -> int:
+ return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
  if not enable_dp_attention:
  return tp_rank, tp_size, 0
@@ -89,18 +160,24 @@ compute_dp_attention_local_info(


  def initialize_dp_attention(
- enable_dp_attention: bool,
- tp_rank: int,
- tp_size: int,
- dp_size: int,
- moe_dense_tp_size: int,
- pp_size: int,
+ server_args: ServerArgs,
+ model_config: ModelConfig,
  ):
  global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
- global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+ global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG

  from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

+ enable_dp_attention = server_args.enable_dp_attention
+ tp_size = server_args.tp_size
+ dp_size = server_args.dp_size
+ moe_dense_tp_size = server_args.moe_dense_tp_size
+ pp_size = server_args.pp_size
+
+ tp_rank = get_tensor_model_parallel_rank()
+
+ _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
  _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
  enable_dp_attention, tp_rank, tp_size, dp_size
  )
@@ -135,38 +212,48 @@
  group_name="attention_tp",
  )

+ _DpGatheredBufferWrapper.set_metadata(
+ hidden_size=model_config.hidden_size,
+ dtype=model_config.dtype,
+ device=torch.device("cuda"),
+ )

- def get_attention_tp_group():
+
+ def is_dp_attention_enabled() -> bool:
+ return _ENABLE_DP_ATTENTION_FLAG
+
+
+ def get_attention_tp_group() -> GroupCoordinator:
  assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
  return _ATTN_TP_GROUP


- def get_attention_tp_rank():
+ def get_attention_tp_rank() -> int:
  assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
  return _ATTN_TP_RANK


- def get_attention_tp_size():
+ def get_attention_tp_size() -> int:
  assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
  return _ATTN_TP_SIZE


- def get_attention_dp_rank():
+ def get_attention_dp_rank() -> int:
  assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
  return _ATTN_DP_RANK


- def get_attention_dp_size():
+ def get_attention_dp_size() -> int:
  assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
  return _ATTN_DP_SIZE


- def get_local_attention_dp_rank():
+ def get_local_attention_dp_rank() -> int:
  assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
  return _LOCAL_ATTN_DP_RANK


- def get_local_attention_dp_size():
+ def get_local_attention_dp_size() -> int:
  assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
  return _LOCAL_ATTN_DP_SIZE

@@ -292,6 +379,10 @@ def _dp_gather_via_all_gather(
  forward_batch: ForwardBatch,
  is_partial: bool,
  ):
+ if get_attention_tp_size() == 1:
+ get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
+ return
+
  if not is_partial:
  if get_attention_tp_rank() != 0:
  local_tokens.fill_(0)
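The _DpGatheredBufferWrapper added above replaces direct slicing of forward_batch.gathered_buffer across the codebase (see the communicator.py hunks earlier). A rough usage sketch, assuming set_metadata and set_dp_buffer_len have already been called during model and batch setup; gather_hidden_states and the all-gather callable are illustrative, not sglang APIs:

import torch

from sglang.srt.layers.dp_attention import get_local_dp_buffer, set_dp_buffer_len

# During batch preparation, register the per-batch token counts (sizes are placeholders).
set_dp_buffer_len(global_dp_buffer_len=4096, local_dp_buffer_len=1024)

def gather_hidden_states(hidden_states: torch.Tensor, attn_tp_all_gather_into_tensor):
    # Previously: buf = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
    # Now the buffer is created on demand with the registered local length.
    buf = get_local_dp_buffer()
    attn_tp_all_gather_into_tensor(buf, hidden_states)
    return buf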