sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py CHANGED
@@ -175,6 +175,32 @@ def guess_chat_template_name_from_model_path(model_path):
         )


+def _validate_prompt(prompt: str):
+    """Validate that the prompt is not empty or whitespace only."""
+    is_invalid = False
+
+    # Check for empty/whitespace string
+    if isinstance(prompt, str):
+        is_invalid = not prompt.strip()
+    # Check for various invalid list cases: [], [""], [" "], [[]]
+    elif isinstance(prompt, list):
+        is_invalid = not prompt or (
+            len(prompt) == 1
+            and (
+                (isinstance(prompt[0], str) and not prompt[0].strip())
+                or (isinstance(prompt[0], list) and not prompt[0])
+            )
+        )
+
+    if is_invalid:
+        raise HTTPException(
+            status_code=400,
+            detail="Input cannot be empty or contain only whitespace.",
+        )
+
+    return prompt
+
+
 async def v1_files_create(
     file: UploadFile, purpose: str, file_storage_path: str = None
 ):
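To illustrate the behavior of the _validate_prompt helper added above, here is a short sketch (illustrative only, not part of the package diff); it assumes the helper is importable from sglang.srt.openai_api.adapter and that FastAPI is installed, since the helper raises HTTPException:

# Illustrative sketch only; not included in the wheel.
from fastapi import HTTPException

from sglang.srt.openai_api.adapter import _validate_prompt

assert _validate_prompt("Hello") == "Hello"                        # plain strings pass through
assert _validate_prompt(["Hello", "world"]) == ["Hello", "world"]  # multi-element lists pass through

for bad in ["", "   ", [], [""], ["  "], [[]]]:
    try:
        _validate_prompt(bad)
    except HTTPException as e:
        assert e.status_code == 400                                # rejected with HTTP 400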
@@ -529,7 +555,6 @@ def v1_generate_request(
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
                 "min_new_tokens": request.min_tokens,
-                "thinking_budget": request.thinking_budget,
                 "stop": request.stop,
                 "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
@@ -591,7 +616,7 @@ def v1_generate_response(
     echo = False

     if (not isinstance(request, list)) and request.echo:
-        # TODO: handle the case propmt is token ids
+        # TODO: handle the case prompt is token ids
         if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
             # for the case of multiple str prompts
             prompts = request.prompt
@@ -647,7 +672,7 @@ def v1_generate_response(
         finish_reason = ret_item["meta_info"]["finish_reason"]

         if to_file:
-            # to make the choise data json serializable
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -1102,7 +1127,6 @@ def v1_chat_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
-            "thinking_budget": request.thinking_budget,
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
@@ -1755,6 +1779,8 @@ def v1_embedding_request(all_requests, tokenizer_manager):

     for request in all_requests:
         prompt = request.input
+        # Check for empty/whitespace string
+        prompt = _validate_prompt(request.input)
         assert (
             type(prompt) is first_prompt_type
         ), "All prompts must be of the same type in file input settings"
sglang/srt/openai_api/protocol.py CHANGED
@@ -172,7 +172,6 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
-    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -351,13 +350,6 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
-    thinking_budget: Optional[int] = Field(
-        default=None,
-        description="The maximum number of reasoning tokens that can be generated for a request. "
-        "This setting of does not affect the thinking process of models. "
-        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
-        "the reasoning content will be truncated and the final response content will be generated immediately.",
-    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
sglang/srt/reasoning_parser.py CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
        One-time parsing: Detects and parses reasoning sections in the provided text.
        Returns both reasoning content and normal text separately.
        """
-        text = text.replace(self.think_start_token, "")
+        text = text.replace(self.think_start_token, "").strip()
        if self.think_end_token not in text:
            # Assume reasoning was truncated before `</think>` token
            return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
            normal_text = current_text[end_idx + len(self.think_end_token) :]

            return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text
+                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
            )

        # Continue with reasoning content
@@ -147,7 +147,7 @@ class ReasoningParser:

    Args:
        model_type (str): Type of model to parse reasoning from
-        stream_reasoning (bool): If Flase, accumulates reasoning content until complete.
+        stream_reasoning (bool): If False, accumulates reasoning content until complete.
            If True, streams reasoning content as it arrives.
    """

sglang/srt/sampling/custom_logit_processor.py CHANGED
@@ -28,11 +28,26 @@ class CustomLogitProcessor(ABC):
        """Define the callable behavior."""
        raise NotImplementedError

-    def to_str(self) -> str:
+    @classmethod
+    def to_str(cls) -> str:
        """Serialize the callable function to a JSON-compatible string."""
-        return json.dumps({"callable": dill.dumps(self).hex()})
+        return json.dumps({"callable": dill.dumps(cls).hex()})

    @classmethod
    def from_str(cls, json_str: str):
        """Deserialize a callable function from a JSON string."""
-        return _cache_from_str(json_str)
+        return _cache_from_str(json_str)()
+
+
+class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        disallowed_token_ids = custom_param_list[0]["token_ids"]
+        assert all(
+            disallowed_token_ids == c["token_ids"] for c in custom_param_list
+        ), f"{custom_param_list=}"
+        logits[..., disallowed_token_ids] = -float("inf")
+        return logits
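As a usage illustration (not part of the package diff), the classmethod-based serialization round trip and the new DisallowedTokensLogitsProcessor can be exercised on a dummy logits tensor; this assumes torch and dill are installed and uses the module path from the file list above:

# Illustrative sketch only; not included in the wheel.
import torch

from sglang.srt.sampling.custom_logit_processor import DisallowedTokensLogitsProcessor

# to_str() now serializes the class itself; from_str() returns an instance.
processor = DisallowedTokensLogitsProcessor.from_str(
    DisallowedTokensLogitsProcessor.to_str()
)

logits = torch.zeros(2, 8)  # (batch_size, vocab_size) dummy logits
out = processor(logits, custom_param_list=[{"token_ids": [3, 5]}] * 2)
assert torch.isinf(out[:, 3]).all() and (out[:, 3] < 0).all()  # disallowed ids masked to -inf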
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -30,13 +30,8 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool

-    # Use thinking_budget to truncate thinking
-    num_thinking_tokens: Optional[torch.Tensor] = None
-    think_end_ids: Optional[torch.Tensor] = None
-    thinking_budgets: Optional[torch.Tensor] = None
-
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int = 0
+    vocab_size: int
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -81,22 +76,7 @@ class SamplingBatchInfo:
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
-        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
-            think_end_ids = torch.tensor(
-                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
-                dtype=torch.int64,
-            ).to(device, non_blocking=True)
-            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
-                device, non_blocking=True
-            )
-            thinking_budgets = torch.tensor(
-                [r.sampling_params.thinking_budget or -1 for r in reqs],
-                dtype=torch.int64,
-            ).to(device, non_blocking=True)
-        else:
-            think_end_ids = None
-            num_thinking_tokens = None
-            thinking_budgets = None
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -152,9 +132,6 @@ class SamplingBatchInfo:
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
-            think_end_ids=think_end_ids,
-            num_thinking_tokens=num_thinking_tokens,
-            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -169,35 +146,6 @@ class SamplingBatchInfo:
     def __len__(self):
         return len(self.temperatures)

-    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
-        has_budget = self.thinking_budgets > 0
-        if not has_budget.any():
-            return
-        torch.where(
-            has_budget,
-            self.num_thinking_tokens + 1,
-            self.num_thinking_tokens,
-            out=self.num_thinking_tokens,
-        )
-        should_stop = has_budget & (
-            self.num_thinking_tokens - 1 > self.thinking_budgets
-        )
-        next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
-        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
-        if len(batch_indices) > 0:
-            end_token_indices = self.think_end_ids[batch_indices]
-            next_token_logits[batch_indices, end_token_indices] = 0.0
-
-    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
-        if not torch.any(self.thinking_budgets > 0):
-            return
-        torch.where(
-            next_token_ids == self.think_end_ids,
-            torch.tensor(-1, device=self.thinking_budgets.device),
-            self.thinking_budgets,
-            out=self.thinking_budgets,
-        )
-
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
@@ -346,7 +294,7 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True

-        # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+        # Note: because the __len()__ operator is defined on the temperatures tensor,
         # please make sure any merge operation with len(self) or len(other) is done before
         # the merge operation of the temperatures tensor below.
         for item in [
@@ -359,5 +307,5 @@ class SamplingBatchInfo:
            other_val = getattr(other, item, None)
            setattr(self, item, torch.cat([self_val, other_val]))

-        self.is_all_greedy |= other.is_all_greedy
+        self.is_all_greedy &= other.is_all_greedy
        self.need_min_p_sampling |= other.need_min_p_sampling
sglang/srt/sampling/sampling_params.py CHANGED
@@ -30,7 +30,6 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
-        thinking_budget: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -51,6 +50,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
+        stream_interval: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -58,7 +58,6 @@ class SamplingParams:
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
-        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -77,6 +76,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
+        self.stream_interval = stream_interval

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
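For reference, a minimal sketch (not part of the package diff) of constructing SamplingParams with the new per-request stream_interval argument; the comment about overriding the server-wide stream interval is an assumption based on this release, not something shown in the hunk above:

# Illustrative sketch only; not included in the wheel.
from sglang.srt.sampling.sampling_params import SamplingParams

params = SamplingParams(
    max_new_tokens=256,
    temperature=0.7,
    stop=["</s>"],
    stream_interval=4,  # assumption: per-request override of the server-wide stream interval
)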
sglang/srt/server_args.py CHANGED
@@ -98,6 +98,7 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     decode_log_interval: int = 40
+    enable_request_time_stats_logging: bool = False

     # API related
     api_key: Optional[str] = None
@@ -159,6 +160,7 @@ class ServerArgs:
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -305,6 +307,12 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"

+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Overlap scheduler is disabled because of using pipeline parallelism."
+            )
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -317,6 +325,11 @@ class ServerArgs:
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )

+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_attention. "
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
@@ -335,6 +348,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
@@ -767,6 +786,12 @@ class ServerArgs:
             default=ServerArgs.decode_log_interval,
             help="The log interval of decode batch.",
         )
+        parser.add_argument(
+            "--enable-request-time-stats-logging",
+            action="store_true",
+            default=ServerArgs.enable_request_time_stats_logging,
+            help="Enable per request time stats logging",
+        )

         # API related
         parser.add_argument(
@@ -825,7 +850,7 @@ class ServerArgs:
         # Multi-node distributed serving
         parser.add_argument(
             "--dist-init-addr",
-            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
+            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
             type=str,
             help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
@@ -1049,6 +1074,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
+        parser.add_argument(
+            "--enable-dp-lm-head",
+            action="store_true",
+            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
+        )
         parser.add_argument(
             "--enable-ep-moe",
             action="store_true",
@@ -1069,7 +1099,7 @@ class ServerArgs:
             "--cuda-graph-max-bs",
             type=int,
             default=ServerArgs.cuda_graph_max_bs,
-            help="Set the maximum batch size for cuda graph.",
+            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
         )
         parser.add_argument(
             "--cuda-graph-bs",
@@ -1096,7 +1126,7 @@ class ServerArgs:
         parser.add_argument(
             "--triton-attention-reduce-in-fp32",
             action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
         parser.add_argument(
@@ -1188,7 +1218,7 @@ class ServerArgs:
            type=int,
            default=0,
            help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
-            "set it to tp_size can get best optimized performace.",
+            "set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with n_share_experts_fusion automatically set to the TP size.",
        )
        parser.add_argument(
            "--disable-chunked-prefix-cache",
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -82,12 +82,12 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}\n"
+                f"Capture CUDA graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
+                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

@@ -149,7 +149,7 @@ class EAGLEDraftCudaGraphRunner:

         # Run and capture
         def run_once():
-            # Backup two fileds, which will be modified in-place in `draft_forward`.
+            # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states

sglang/srt/speculative/eagle_utils.py CHANGED
@@ -167,12 +167,12 @@ class EagleVerifyOutput:
     draft_input: EagleDraftInput
     # Logit outputs from target worker
     logits_output: LogitsProcessorOutput
-    # Accepeted token ids including the bonus token
+    # Accepted token ids including the bonus token
     verified_id: torch.Tensor
-    # Accepeted token length per sequence in a batch in CPU.
+    # Accepted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
-    # Accepeted indices from logits_output.next_token_logits
-    accepeted_indices: torch.Tensor
+    # Accepted indices from logits_output.next_token_logits
+    accepted_indices: torch.Tensor


 @dataclass
@@ -316,7 +316,7 @@ class EagleVerifyInput:

         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
-        accepeted token logits.
+        accepted token logits.
         """
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
@@ -493,7 +493,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-                accepeted_indices=accept_index,
+                accepted_indices=accept_index,
             )
         else:
             assign_req_to_token_pool[(bs,)](
@@ -539,7 +539,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-                accepeted_indices=accept_index,
+                accepted_indices=accept_index,
             )

sglang/srt/speculative/eagle_worker.py CHANGED
@@ -201,7 +201,7 @@ class EAGLEWorker(TpModelWorker):
             self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
+                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
             )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -245,14 +245,14 @@ class EAGLEWorker(TpModelWorker):
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
         Returns:
-            A tuple of the final logit output of the target model, next tokens accepeted,
-            the batch id (used for overlap schedule), and number of accepeted tokens.
+            A tuple of the final logit output of the target model, next tokens accepted,
+            the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch = self.verify(
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )

             # If it is None, it means all requests are finished
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
             )

-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False

     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-        logits_output, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch, skip_sample=True
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
@@ -491,11 +494,11 @@ class EAGLEWorker(TpModelWorker):
         )

         # Post process based on verified outputs.
-        # Pick indices that we care (accepeted)
+        # Pick indices that we care (accepted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices
+            res.accepted_indices
         ]
-        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)

-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
         self,
@@ -590,14 +593,14 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        # Backup fileds that will be modified in-place
+        # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
         )

         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)

         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
sglang/srt/utils.py CHANGED
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()

     return free_gpu_memory / (1 << 30)
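A standalone sketch (not library code) of the reduction pattern the updated get_available_gpu_memory uses: the per-rank free-memory value is min-reduced on a CPU tensor over a caller-supplied process group instead of first being copied onto the GPU. The single-process gloo setup below is illustrative only:

# Illustrative sketch only; not included in the wheel.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

free_gpu_memory = 12 * (1 << 30)  # stand-in value; the real code queries the device
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
# sglang now passes its dedicated CPU group here via the new cpu_group argument.
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
print(tensor.item() / (1 << 30), "GiB free across all ranks")

dist.destroy_process_group()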
@@ -2076,7 +2078,6 @@ def is_fa3_default_architecture(hf_config):
         "Llama4ForConditionalGeneration",
         "LlamaForCausalLM",
         "MistralForCausalLM",
-        "MixtralForCausalLM",
         "Gemma2ForCausalLM",
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
sglang/test/few_shot_gsm8k.py CHANGED
@@ -90,7 +90,7 @@ def run_eval(args):
    #####################################

    # Run requests
-    tic = time.time()
+    tic = time.perf_counter()
    states = few_shot_gsm8k.run_batch(
        arguments,
        temperature=args.temperature if hasattr(args, "temperature") else 0,
@@ -99,7 +99,7 @@ def run_eval(args):
        return_logprob=getattr(args, "return_logprob", None),
        logprob_start_len=getattr(args, "logprob_start_len", None),
    )
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic

    preds = []
    for i in range(len(states)):
sglang/test/few_shot_gsm8k_engine.py CHANGED
@@ -89,7 +89,7 @@ def run_eval(args):
    }

    # Run requests
-    tic = time.time()
+    tic = time.perf_counter()

    loop = asyncio.get_event_loop()

@@ -98,7 +98,7 @@ def run_eval(args):
    )

    # End requests
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic

    # Shutdown the engine
    engine.shutdown()