sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -175,6 +175,32 @@ def guess_chat_template_name_from_model_path(model_path):
     )
 
 
+def _validate_prompt(prompt: str):
+    """Validate that the prompt is not empty or whitespace only."""
+    is_invalid = False
+
+    # Check for empty/whitespace string
+    if isinstance(prompt, str):
+        is_invalid = not prompt.strip()
+    # Check for various invalid list cases: [], [""], [" "], [[]]
+    elif isinstance(prompt, list):
+        is_invalid = not prompt or (
+            len(prompt) == 1
+            and (
+                (isinstance(prompt[0], str) and not prompt[0].strip())
+                or (isinstance(prompt[0], list) and not prompt[0])
+            )
+        )
+
+    if is_invalid:
+        raise HTTPException(
+            status_code=400,
+            detail="Input cannot be empty or contain only whitespace.",
+        )
+
+    return prompt
+
+
 async def v1_files_create(
     file: UploadFile, purpose: str, file_storage_path: str = None
 ):
@@ -529,7 +555,6 @@ def v1_generate_request(
         "temperature": request.temperature,
         "max_new_tokens": request.max_tokens,
         "min_new_tokens": request.min_tokens,
-        "thinking_budget": request.thinking_budget,
         "stop": request.stop,
         "stop_token_ids": request.stop_token_ids,
         "top_p": request.top_p,
@@ -591,7 +616,7 @@ def v1_generate_response(
     echo = False
 
     if (not isinstance(request, list)) and request.echo:
-        # TODO: handle the case
+        # TODO: handle the case prompt is token ids
         if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
             # for the case of multiple str prompts
             prompts = request.prompt
@@ -647,7 +672,7 @@ def v1_generate_response(
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         if to_file:
-            # to make the
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -1102,7 +1127,6 @@ def v1_chat_generate_request(
         "temperature": request.temperature,
         "max_new_tokens": request.max_tokens or request.max_completion_tokens,
         "min_new_tokens": request.min_tokens,
-        "thinking_budget": request.thinking_budget,
         "stop": stop,
         "stop_token_ids": request.stop_token_ids,
         "top_p": request.top_p,
@@ -1755,6 +1779,8 @@ def v1_embedding_request(all_requests, tokenizer_manager):
 
     for request in all_requests:
         prompt = request.input
+        # Check for empty/whitespace string
+        prompt = _validate_prompt(request.input)
         assert (
             type(prompt) is first_prompt_type
         ), "All prompts must be of the same type in file input settings"
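The `_validate_prompt` helper added above rejects empty or whitespace-only inputs with an HTTP 400 before they reach the engine, and the last hunk wires it into `v1_embedding_request`. A minimal sketch of the resulting behavior against the OpenAI-compatible embeddings endpoint; host, port, and model name below are placeholders, not part of this release:

    import requests

    # A whitespace-only input should now be rejected with status 400 and the
    # detail "Input cannot be empty or contain only whitespace."
    resp = requests.post(
        "http://127.0.0.1:30000/v1/embeddings",
        json={"model": "my-embedding-model", "input": "   "},
    )
    print(resp.status_code, resp.json())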
sglang/srt/openai_api/protocol.py
CHANGED
@@ -172,7 +172,6 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
-    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -351,13 +350,6 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
-    thinking_budget: Optional[int] = Field(
-        default=None,
-        description="The maximum number of reasoning tokens that can be generated for a request. "
-        "This setting of does not affect the thinking process of models. "
-        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
-        "the reasoning content will be truncated and the final response content will be generated immediately.",
-    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
sglang/srt/reasoning_parser.py
CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "")
+        text = text.replace(self.think_start_token, "").strip()
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
             normal_text = current_text[end_idx + len(self.think_end_token) :]
 
             return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text
+                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
             )
 
         # Continue with reasoning content
@@ -147,7 +147,7 @@ class ReasoningParser:
 
     Args:
         model_type (str): Type of model to parse reasoning from
-        stream_reasoning (bool): If
+        stream_reasoning (bool): If False, accumulates reasoning content until complete.
            If True, streams reasoning content as it arrives.
    """
sglang/srt/sampling/custom_logit_processor.py
CHANGED
@@ -28,11 +28,26 @@ class CustomLogitProcessor(ABC):
         """Define the callable behavior."""
         raise NotImplementedError
 
-
+    @classmethod
+    def to_str(cls) -> str:
         """Serialize the callable function to a JSON-compatible string."""
-        return json.dumps({"callable": dill.dumps(
+        return json.dumps({"callable": dill.dumps(cls).hex()})
 
     @classmethod
     def from_str(cls, json_str: str):
         """Deserialize a callable function from a JSON string."""
-        return _cache_from_str(json_str)
+        return _cache_from_str(json_str)()
+
+
+class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        disallowed_token_ids = custom_param_list[0]["token_ids"]
+        assert all(
+            disallowed_token_ids == c["token_ids"] for c in custom_param_list
+        ), f"{custom_param_list=}"
+        logits[..., disallowed_token_ids] = -float("inf")
+        return logits
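`DisallowedTokensLogitsProcessor` is a concrete `CustomLogitProcessor`: it reads a `token_ids` list from each request's `custom_params` and masks those ids to `-inf` before sampling. With `to_str()` now serializing the class itself and `from_str` instantiating it on the scheduler side, a request can ship the processor as a string. A hedged usage sketch; apart from the names visible in the diff, the engine keyword arguments and return shape below are assumptions:

    import sglang as sgl
    from sglang.srt.sampling.custom_logit_processor import DisallowedTokensLogitsProcessor

    # enable_custom_logit_processor mirrors the flag checked in SamplingBatchInfo.
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        enable_custom_logit_processor=True,
    )
    out = llm.generate(
        "The capital of France is",
        sampling_params={
            "max_new_tokens": 16,
            "custom_params": {"token_ids": [128001]},  # example ids to ban
        },
        custom_logit_processor=DisallowedTokensLogitsProcessor().to_str(),
    )
    print(out["text"])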
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -30,13 +30,8 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
-    # Use thinking_budget to truncate thinking
-    num_thinking_tokens: Optional[torch.Tensor] = None
-    think_end_ids: Optional[torch.Tensor] = None
-    thinking_budgets: Optional[torch.Tensor] = None
-
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int
+    vocab_size: int
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -81,22 +76,7 @@ class SamplingBatchInfo:
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
-
-        think_end_ids = torch.tensor(
-            [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
-            dtype=torch.int64,
-        ).to(device, non_blocking=True)
-        num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
-            device, non_blocking=True
-        )
-        thinking_budgets = torch.tensor(
-            [r.sampling_params.thinking_budget or -1 for r in reqs],
-            dtype=torch.int64,
-        ).to(device, non_blocking=True)
-        else:
-            think_end_ids = None
-            num_thinking_tokens = None
-            thinking_budgets = None
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -152,9 +132,6 @@ class SamplingBatchInfo:
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
-            think_end_ids=think_end_ids,
-            num_thinking_tokens=num_thinking_tokens,
-            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -169,35 +146,6 @@ class SamplingBatchInfo:
     def __len__(self):
         return len(self.temperatures)
 
-    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
-        has_budget = self.thinking_budgets > 0
-        if not has_budget.any():
-            return
-        torch.where(
-            has_budget,
-            self.num_thinking_tokens + 1,
-            self.num_thinking_tokens,
-            out=self.num_thinking_tokens,
-        )
-        should_stop = has_budget & (
-            self.num_thinking_tokens - 1 > self.thinking_budgets
-        )
-        next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
-        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
-        if len(batch_indices) > 0:
-            end_token_indices = self.think_end_ids[batch_indices]
-            next_token_logits[batch_indices, end_token_indices] = 0.0
-
-    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
-        if not torch.any(self.thinking_budgets > 0):
-            return
-        torch.where(
-            next_token_ids == self.think_end_ids,
-            torch.tensor(-1, device=self.thinking_budgets.device),
-            self.thinking_budgets,
-            out=self.thinking_budgets,
-        )
-
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
@@ -346,7 +294,7 @@ class SamplingBatchInfo:
         # Set the flag to True if any of the two has custom logit processor
         self.has_custom_logit_processor = True
 
-        # Note:
+        # Note: because the __len()__ operator is defined on the temperatures tensor,
         # please make sure any merge operation with len(self) or len(other) is done before
         # the merge operation of the temperatures tensor below.
         for item in [
@@ -359,5 +307,5 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.cat([self_val, other_val]))
 
-        self.is_all_greedy
+        self.is_all_greedy &= other.is_all_greedy
         self.need_min_p_sampling |= other.need_min_p_sampling
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -30,7 +30,6 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
-        thinking_budget: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
@@ -51,6 +50,7 @@ class SamplingParams:
        spaces_between_special_tokens: bool = True,
        no_stop_trim: bool = False,
        custom_params: Optional[Dict[str, Any]] = None,
+        stream_interval: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -58,7 +58,6 @@ class SamplingParams:
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
-        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -77,6 +76,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
+        self.stream_interval = stream_interval
 
         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
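`SamplingParams` now carries an optional per-request `stream_interval`, which can override the server-wide streaming interval. A hedged sketch against the native `/generate` endpoint; host, port, and the payload values are placeholders:

    import requests

    resp = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "Write a haiku about GPUs.",
            "sampling_params": {
                "max_new_tokens": 64,
                "stream_interval": 4,  # assumed: flush streamed chunks every 4 decoded tokens
            },
            "stream": True,
        },
        stream=True,
    )
    for line in resp.iter_lines():
        if line:
            print(line.decode())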
sglang/srt/server_args.py
CHANGED
@@ -98,6 +98,7 @@ class ServerArgs:
     show_time_cost: bool = False
     enable_metrics: bool = False
     decode_log_interval: int = 40
+    enable_request_time_stats_logging: bool = False
 
     # API related
     api_key: Optional[str] = None
@@ -159,6 +160,7 @@ class ServerArgs:
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -305,6 +307,12 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Overlap scheduler is disabled because of using pipeline parallelism."
+            )
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -317,6 +325,11 @@ class ServerArgs:
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )
 
+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_attention. "
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
@@ -335,6 +348,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
@@ -767,6 +786,12 @@ class ServerArgs:
             default=ServerArgs.decode_log_interval,
             help="The log interval of decode batch.",
         )
+        parser.add_argument(
+            "--enable-request-time-stats-logging",
+            action="store_true",
+            default=ServerArgs.enable_request_time_stats_logging,
+            help="Enable per request time stats logging",
+        )
 
         # API related
         parser.add_argument(
@@ -825,7 +850,7 @@ class ServerArgs:
         # Multi-node distributed serving
         parser.add_argument(
             "--dist-init-addr",
-            "--nccl-init-addr",  # For backward
+            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
             type=str,
             help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
@@ -1049,6 +1074,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
+        parser.add_argument(
+            "--enable-dp-lm-head",
+            action="store_true",
+            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
+        )
         parser.add_argument(
             "--enable-ep-moe",
             action="store_true",
@@ -1069,7 +1099,7 @@ class ServerArgs:
             "--cuda-graph-max-bs",
             type=int,
             default=ServerArgs.cuda_graph_max_bs,
-            help="Set the maximum batch size for cuda graph.",
+            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
         )
         parser.add_argument(
             "--cuda-graph-bs",
@@ -1096,7 +1126,7 @@ class ServerArgs:
         parser.add_argument(
             "--triton-attention-reduce-in-fp32",
             action="store_true",
-            help="Cast the
+            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
         parser.add_argument(
@@ -1188,7 +1218,7 @@ class ServerArgs:
            type=int,
            default=0,
            help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
-           "set it to tp_size can get best optimized
+           "set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with n_share_experts_fusion automatically set to the TP size.",
        )
        parser.add_argument(
            "--disable-chunked-prefix-cache",
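Two new server options show up in this section: `--enable-request-time-stats-logging` and `--enable-dp-lm-head` (the latter only makes sense together with `--enable-dp-attention`). A hedged sketch using the offline engine, whose keyword arguments mirror the `ServerArgs` fields; the model path and parallel sizes below are placeholders:

    import sglang as sgl

    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
        tp_size=8,
        dp_size=8,
        enable_dp_attention=True,
        enable_dp_lm_head=True,                     # new in 0.4.6.post4
        enable_request_time_stats_logging=True,     # new in 0.4.6.post4
    )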
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
@@ -82,12 +82,12 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture
+                f"Capture CUDA graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable
+                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
@@ -149,7 +149,7 @@ class EAGLEDraftCudaGraphRunner:
 
         # Run and capture
         def run_once():
-            # Backup two
+            # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
 
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -167,12 +167,12 @@ class EagleVerifyOutput:
     draft_input: EagleDraftInput
     # Logit outputs from target worker
     logits_output: LogitsProcessorOutput
-    #
+    # Accepted token ids including the bonus token
     verified_id: torch.Tensor
-    #
+    # Accepted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
-    #
-
+    # Accepted indices from logits_output.next_token_logits
+    accepted_indices: torch.Tensor
 
 
 @dataclass
@@ -316,7 +316,7 @@ class EagleVerifyInput:
 
         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
-
+        accepted token logits.
         """
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
@@ -493,7 +493,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-
+                accepted_indices=accept_index,
             )
         else:
             assign_req_to_token_pool[(bs,)](
@@ -539,7 +539,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-
+                accepted_indices=accept_index,
             )
 
 
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -201,7 +201,7 @@ class EAGLEWorker(TpModelWorker):
             self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
-                f"EAGLE is not
+                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
             )
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -245,14 +245,14 @@ class EAGLEWorker(TpModelWorker):
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
         Returns:
-            A tuple of the final logit output of the target model, next tokens
-            the batch id (used for overlap schedule), and number of
+            A tuple of the final logit output of the target model, next tokens accepted,
+            the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch =
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )
 
             # If it is None, it means all requests are finished
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids =
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
             )
 
-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False
 
     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-        logits_output, _ =
-
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
@@ -491,11 +494,11 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Post process based on verified outputs.
-        # Pick indices that we care (
+        # Pick indices that we care (accepted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.
+            res.accepted_indices
         ]
-        logits_output.hidden_states = logits_output.hidden_states[res.
+        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
         self,
@@ -590,14 +593,14 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        # Backup
+        # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
 
         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
sglang/srt/utils.py
CHANGED
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
-
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()
 
     return free_gpu_memory / (1 << 30)
@@ -2076,7 +2078,6 @@ def is_fa3_default_architecture(hf_config):
     "Llama4ForConditionalGeneration",
     "LlamaForCausalLM",
     "MistralForCausalLM",
-    "MixtralForCausalLM",
     "Gemma2ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Qwen3ForCausalLM",
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -90,7 +90,7 @@ def run_eval(args):
     #####################################
 
     # Run requests
-    tic = time.
+    tic = time.perf_counter()
    states = few_shot_gsm8k.run_batch(
        arguments,
        temperature=args.temperature if hasattr(args, "temperature") else 0,
@@ -99,7 +99,7 @@ def run_eval(args):
        return_logprob=getattr(args, "return_logprob", None),
        logprob_start_len=getattr(args, "logprob_start_len", None),
    )
-    latency = time.
+    latency = time.perf_counter() - tic
 
     preds = []
     for i in range(len(states)):
sglang/test/few_shot_gsm8k_engine.py
CHANGED
@@ -89,7 +89,7 @@ def run_eval(args):
     }
 
     # Run requests
-    tic = time.
+    tic = time.perf_counter()
 
     loop = asyncio.get_event_loop()
 
@@ -98,7 +98,7 @@ def run_eval(args):
     )
 
     # End requests
-    latency = time.
+    latency = time.perf_counter() - tic
 
     # Shutdown the engine
     engine.shutdown()
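The timing change in the test scripts swaps wall-clock reads for `time.perf_counter()`, a monotonic, high-resolution clock that is unaffected by system clock adjustments and is the usual choice for latency measurement. The pattern in isolation:

    import time

    tic = time.perf_counter()
    time.sleep(0.1)  # stand-in for the benchmarked call (run_batch / the async loop in the tests)
    latency = time.perf_counter() - tic
    print(f"latency: {latency:.3f} s")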