sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_score.py
ADDED
@@ -0,0 +1,61 @@
+import logging
+from typing import Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    ScoringRequest,
+    ScoringResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingScore(OpenAIServingBase):
+    """Handler for /v1/score requests"""
+
+    # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
+
+    def _request_id_prefix(self) -> str:
+        return "score-"
+
+    def _convert_to_internal_request(
+        self,
+        request: ScoringRequest,
+    ) -> tuple[ScoringRequest, ScoringRequest]:
+        """Convert OpenAI scoring request to internal format"""
+        # For scoring, we pass the request directly as the tokenizer_manager
+        # has a specialized score_request method that doesn't use GenerateReqInput
+
+        return request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: ScoringRequest,
+        request: ScoringRequest,
+        raw_request: Request,
+    ) -> Union[ScoringResponse, ErrorResponse]:
+        """Handle the scoring request"""
+        try:
+            # Use tokenizer_manager's score_request method directly
+            scores = await self.tokenizer_manager.score_request(
+                query=request.query,
+                items=request.items,
+                label_token_ids=request.label_token_ids,
+                apply_softmax=request.apply_softmax,
+                item_first=request.item_first,
+                request=raw_request,
+            )
+
+            # Create response with just the scores, without usage info
+            response = ScoringResponse(
+                scores=scores,
+                model=request.model,
+            )
+            return response
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
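For orientation, a client call against this handler might look like the sketch below. The field names mirror the ScoringRequest fields used above; the base URL, port, model name, and token IDs are illustrative assumptions rather than values taken from the diff.

```python
# Hypothetical client-side call; only the field names come from the
# ScoringRequest usage above. URL, port, model, and token ids are made up.
import requests

payload = {
    "model": "my-model",
    "query": "Is the sky blue?",
    "items": ["yes", "no"],
    "label_token_ids": [9891, 2201],  # illustrative token ids for the labels
    "apply_softmax": True,
    "item_first": False,
}
resp = requests.post("http://localhost:30000/v1/score", json=payload)
print(resp.json()["scores"])
```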
sglang/srt/entrypoints/openai/usage_processor.py
ADDED
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
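To illustrate how these helpers fit together, the sketch below runs a fabricated responses list (two choices generated for one prompt, with the dict shape taken from the accesses above) through calculate_response_usage; the token counts are made up.

```python
# Fabricated meta_info values; the dict shape follows the accesses above.
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

responses = [
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 30, "cached_tokens": 8}},
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 25, "cached_tokens": 8}},
]

# Two choices were generated for one prompt, so prompt tokens are counted once.
usage = UsageProcessor.calculate_response_usage(
    responses, n_choices=2, enable_cache_report=True
)
# prompt_tokens=12, completion_tokens=55, total_tokens=67,
# prompt_tokens_details={"cached_tokens": 16}
```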
sglang/srt/entrypoints/openai/utils.py
ADDED
@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def to_openai_style_logprobs(
+    input_token_logprobs=None,
+    output_token_logprobs=None,
+    input_top_logprobs=None,
+    output_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if input_token_logprobs is not None:
+        append_token_logprobs(input_token_logprobs)
+    if output_token_logprobs is not None:
+        append_token_logprobs(output_token_logprobs)
+    if input_top_logprobs is not None:
+        append_top_logprobs(input_top_logprobs)
+    if output_top_logprobs is not None:
+        append_top_logprobs(output_top_logprobs)
+
+    return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
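A quick usage sketch for the logprob converter above; the (logprob, token_id, token_text) tuple layout is implied by the unpacking in append_token_logprobs, and the numeric values are made up.

```python
# Made-up logprob values; the tuple layout (logprob, token_id, token_text)
# follows the unpacking `for logprob, _, token_text in token_logprobs` above.
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs

lp = to_openai_style_logprobs(
    output_token_logprobs=[(-0.11, 1234, "Hello"), (-0.87, 5678, " world")],
    output_top_logprobs=[
        [(-0.11, 1234, "Hello"), (-2.30, 4321, "Hi")],
        None,  # no alternatives recorded for the second position
    ],
)
# lp.tokens == ["Hello", " world"]
# lp.token_logprobs == [-0.11, -0.87]
# lp.top_logprobs == [{"Hello": -0.11, "Hi": -2.3}, None]
```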
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
     ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
     _is_complete_json,
     _partial_json_loads,
 )
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -111,11 +111,10 @@ class BaseFormatDetector(ABC):
         # The current_text has tool_call if it is the start of a new tool call sequence
         # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
         if not (
-            self.
-            or current_text.startswith("{")
+            self.has_tool_call(current_text)
             or (
                 self.current_tool_id > 0
-                and current_text.startswith(self.tool_call_separator
+                and current_text.startswith(self.tool_call_separator)
             )
         ):
             # Only clear buffer if we're sure no tool call is starting
@@ -143,6 +142,10 @@ class BaseFormatDetector(ABC):
         try:
             if current_text.startswith(self.bot_token):
                 start_idx = len(self.bot_token)
+            elif self.current_tool_id > 0 and current_text.startswith(
+                self.tool_call_separator + self.bot_token
+            ):
+                start_idx = len(self.tool_call_separator + self.bot_token)
             elif self.current_tool_id > 0 and current_text.startswith(
                 self.tool_call_separator
             ):
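The new elif branch above covers a second tool call that arrives prefixed by the separator and the bot token together. A standalone toy illustration of the start-index selection (the token strings are made up for the example, not sglang's values):

```python
# Toy illustration of the new start-index branch; bot_token and separator
# values are made up for the example.
bot_token = "<tool_call>"
tool_call_separator = "\n"
current_tool_id = 1  # a previous tool call has already been parsed

current_text = '\n<tool_call>{"name": "get_weather"}'
if current_text.startswith(bot_token):
    start_idx = len(bot_token)
elif current_tool_id > 0 and current_text.startswith(tool_call_separator + bot_token):
    start_idx = len(tool_call_separator + bot_token)
elif current_tool_id > 0 and current_text.startswith(tool_call_separator):
    start_idx = len(tool_call_separator)
else:
    start_idx = 0

print(current_text[start_idx:])  # {"name": "get_weather"}
```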
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
 from sglang.srt.function_call.utils import _is_complete_json
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -211,20 +211,74 @@ class EBNFComposer:
         properties = params.get("properties", {})
         required_props = set(params.get("required", []))
 
-        #
-
+        # The generated pattern ensures:
+        # 1. Required properties appear first, joined by commas
+        # 2. Optional properties are wrapped with comma included: ( "," ( "prop" : value )? )?
+        # 3. For multiple optional properties, we allow flexible ordering:
+        #    - Each optional can be skipped entirely
+        #    - They can appear in any combination
+        #
+        # Example patterns generated:
+        # - One required, one optional:
+        #   "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
+        #   Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
+        #
+        # - Multiple optional properties with flexible ordering:
+        #   "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
+        #   Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
+        #           {"req": "x", "opt1": "y", "opt2": "z"}
+        #
+        # - All optional properties with flexible ordering:
+        #   "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
+        #   Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
+
+        prop_kv_pairs = {}
+        ordered_props = list(properties.keys())
+
         for prop_name, prop_schema in properties.items():
             value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
             # Create key=value pair
             pair = key_value_template.format(key=prop_name, valrule=value_rule)
-
-
-
-
-
-
-            #
-
+            prop_kv_pairs[prop_name] = pair
+
+        # Separate into required and optional while preserving order
+        required = [p for p in ordered_props if p in required_props]
+        optional = [p for p in ordered_props if p not in required_props]
+
+        # Build the combined rule
+        rule_parts = []
+
+        # Add required properties joined by commas
+        if required:
+            rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
+
+        # Add optional properties with flexible ordering
+        if optional:
+            # Build alternatives where any optional property can appear first
+            opt_alternatives = []
+            for i in range(len(optional)):
+                # Build pattern for optional[i] appearing first
+                opt_parts = []
+                for j in range(i, len(optional)):
+                    if j == i:
+                        opt_parts.append(prop_kv_pairs[optional[j]])
+                    else:
+                        opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
+                opt_alternatives.append("".join(opt_parts))
+
+            # Wrap with appropriate comma handling based on whether we have required properties
+            if required:
+                # Required properties exist, so optional group needs outer comma
+                rule_parts.append(' ( "," ( ')
+                rule_parts.append(" | ".join(opt_alternatives))
+                rule_parts.append(" ) )?")
+            else:
+                # All properties are optional
+                rule_parts.append("( ")
+                rule_parts.append(" | ".join(opt_alternatives))
+                rule_parts.append(" )?")
+
+        combined_args = "".join(rule_parts)
         arguments_rule = args_template.format(arg_rules=combined_args)
 
         # Add the function call rule and its arguments rule
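To make the optional-property ordering above concrete, here is a standalone re-run of just that loop for one required and two optional properties. It is an illustration rather than the EBNFComposer code itself, and the key/value rule strings are placeholders.

```python
# Standalone illustration of the flexible-ordering construction shown above.
# The key/value rule strings are placeholders, not real EBNF sub-rules.
prop_kv_pairs = {
    "req": '"req" ":" string',
    "opt1": '"opt1" ":" string',
    "opt2": '"opt2" ":" number',
}
required = ["req"]
optional = ["opt1", "opt2"]

rule_parts = [' "," '.join(prop_kv_pairs[k] for k in required)]

opt_alternatives = []
for i in range(len(optional)):
    parts = [prop_kv_pairs[optional[i]]]
    for j in range(i + 1, len(optional)):
        parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
    opt_alternatives.append("".join(parts))

rule_parts += [' ( "," ( ', " | ".join(opt_alternatives), " ) )?"]
print("".join(rule_parts))
# "req" ":" string ( "," ( "opt1" ":" string ( "," "opt2" ":" number )? | "opt2" ":" number ) )?
```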
@@ -1,6 +1,12 @@
 import logging
 from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
 
+from sglang.srt.entrypoints.openai.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+    ToolChoice,
+)
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
-from sglang.srt.openai_api.protocol import (
-    StructuralTagResponseFormat,
-    StructuresResponseFormat,
-    Tool,
-    ToolChoice,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -2,6 +2,7 @@ import json
 import logging
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -4,6 +4,7 @@ import logging
 import re
 from typing import List, Optional
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -1,11 +1,12 @@
-"""
-
+"""Template utilities for Jinja template processing.
+
+This module provides utilities for analyzing and processing Jinja chat templates,
+including content format detection and message processing.
 """
 
 import logging
-from typing import Dict, List
 
-import jinja2
+import jinja2
 import transformers.utils.chat_template_utils as hf_chat_utils
 
 logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
     return None
 
 
-def
+def detect_jinja_template_content_format(chat_template: str) -> str:
     """
     Detect whether a chat template expects 'string' or 'openai' content format.
 
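For reference, the two content formats the detector distinguishes correspond to the two common shapes of a chat message payload; the dicts below are generic examples, not sglang data.

```python
# 'string' content format: the message content is a plain string.
string_format_msg = {"role": "user", "content": "Describe this image."}

# 'openai' content format: the content is a list of typed parts.
openai_format_msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}
```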
sglang/srt/layers/activation.py
CHANGED
@@ -29,10 +29,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_npu,
+    set_weight_attrs,
+)
 from sglang.utils import resolve_obj_by_qualname
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -52,6 +61,15 @@ class SiluAndMul(CustomOp):
         silu_and_mul(x, out)
         return out
 
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available:
+            d = x.shape[-1] // 2
+            output_shape = x.shape[:-1] + (d,)
+            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+            return out
+        else:
+            return self.forward_native(x)
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
@@ -184,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
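For reference, the split-and-multiply semantics that both the CUDA kernel and the new CPU AMX path implement can be written in plain PyTorch. This sketch mirrors the usual SiluAndMul reference computation and is not the sgl-kernel code.

```python
# Plain-PyTorch reference for SiluAndMul: SiLU on the first half of the last
# dimension, multiplied elementwise by the second half.
import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)            # last dim is 2 * d
out = silu_and_mul_reference(x)   # shape: (4, 8)
```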
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 
@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
@@ -1704,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):
 
             # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
            metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
            # metadata_expand.max_seq_len_q = 1, already set in capture
            # metadata_expand.cu_seqlens_q already set in capture
-
            offsets = torch.arange(
                self.speculative_num_draft_tokens, device=device
            ).unsqueeze(
                0
            )  # shape: (1, self.speculative_num_draft_tokens)
+
            cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
            cum_len = torch.nn.functional.pad(
                torch.cumsum(
@@ -1728,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
            ).view(1, -1)
            # avoid extracting padded seq indices which will be out of boundary
            mask_extraction_indices[
-                :,
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
            ].fill_(0)
-
            mask = spec_info.custom_mask[mask_extraction_indices].view(
                -1, self.speculative_num_draft_tokens
            )  # (bsz * draft_num, draft_num)
+
            col_indices = offsets.expand(
                mask.shape[0], self.speculative_num_draft_tokens
            )
            keys = torch.where(
-                mask,
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
            )
            _, sort_order = torch.sort(keys, dim=1)
 
@@ -1747,6 +1751,7 @@ class FlashAttentionBackend(AttentionBackend):
                .gather(1, cols)
                .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
            )  # (bsz, draft_num)
+
            metadata_expand.page_table.copy_(
                non_masked_page_table.gather(1, sort_order)
            )
@@ -1758,6 +1763,7 @@ class FlashAttentionBackend(AttentionBackend):
                    dtype=torch.int32,
                )
            )
+
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1773,11 @@ class FlashAttentionBackend(AttentionBackend):
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )
            accept_length = spec_info.accept_length[:bs]
-
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
@@ -1807,7 +1817,7 @@ class FlashAttentionBackend(AttentionBackend):
 
    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
-        return
+        return 1
 
    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -1999,9 +2009,9 @@ class FlashAttentionMultiStepBackend:
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
 
    def init_forward_metadata_capture_cuda_graph(
        self,