sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_score.py
ADDED
@@ -0,0 +1,61 @@
+import logging
+from typing import Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    ScoringRequest,
+    ScoringResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingScore(OpenAIServingBase):
+    """Handler for /v1/score requests"""
+
+    # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
+
+    def _request_id_prefix(self) -> str:
+        return "score-"
+
+    def _convert_to_internal_request(
+        self,
+        request: ScoringRequest,
+    ) -> tuple[ScoringRequest, ScoringRequest]:
+        """Convert OpenAI scoring request to internal format"""
+        # For scoring, we pass the request directly as the tokenizer_manager
+        # has a specialized score_request method that doesn't use GenerateReqInput
+
+        return request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: ScoringRequest,
+        request: ScoringRequest,
+        raw_request: Request,
+    ) -> Union[ScoringResponse, ErrorResponse]:
+        """Handle the scoring request"""
+        try:
+            # Use tokenizer_manager's score_request method directly
+            scores = await self.tokenizer_manager.score_request(
+                query=request.query,
+                items=request.items,
+                label_token_ids=request.label_token_ids,
+                apply_softmax=request.apply_softmax,
+                item_first=request.item_first,
+                request=raw_request,
+            )
+
+            # Create response with just the scores, without usage info
+            response = ScoringResponse(
+                scores=scores,
+                model=request.model,
+            )
+            return response
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
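
For context, a minimal client-side sketch of exercising this handler (assuming a locally running sglang server on port 30000; the payload fields mirror the ScoringRequest attributes used above, and the model name and label token ids are hypothetical):

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/score",
        json={
            "model": "my-model",              # hypothetical model name
            "query": "Is the sky blue?",
            "items": ["yes", "no"],
            "label_token_ids": [9891, 2201],  # hypothetical label token ids
            "apply_softmax": True,
            "item_first": False,
        },
    )
    # ScoringResponse carries only scores and model, no usage info
    print(resp.json()["scores"])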
sglang/srt/entrypoints/openai/usage_processor.py
ADDED
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
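
An illustrative sketch of the streaming path (the per-index maps below are made up): with n_choices choices per prompt, only indices where idx % n_choices == 0 contribute prompt tokens, so each prompt is counted once rather than once per choice.

    usage = UsageProcessor.calculate_streaming_usage(
        prompt_tokens={0: 12, 1: 12, 2: 7, 3: 7},      # per-choice prompt counts
        completion_tokens={0: 30, 1: 25, 2: 10, 3: 8},
        cached_tokens={},
        n_choices=2,
    )
    # Indices 0 and 2 are the first choice of each prompt, so
    # usage.prompt_tokens == 12 + 7 == 19, usage.completion_tokens == 73,
    # and usage.total_tokens == 92.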
sglang/srt/entrypoints/openai/utils.py
ADDED
@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def to_openai_style_logprobs(
+    input_token_logprobs=None,
+    output_token_logprobs=None,
+    input_top_logprobs=None,
+    output_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if input_token_logprobs is not None:
+        append_token_logprobs(input_token_logprobs)
+    if output_token_logprobs is not None:
+        append_token_logprobs(output_token_logprobs)
+    if input_top_logprobs is not None:
+        append_top_logprobs(input_top_logprobs)
+    if output_top_logprobs is not None:
+        append_top_logprobs(output_top_logprobs)
+
+    return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
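
A small usage sketch (values invented): the token-logprob inputs are iterables of (logprob, token_id, token_text) triples, and each top-logprob entry is either such a list or None.

    lp = to_openai_style_logprobs(
        output_token_logprobs=[(-0.1, 42, "Hello"), (-1.3, 7, "!")],
        output_top_logprobs=[[(-0.1, 42, "Hello"), (-2.4, 13, "Hi")], None],
    )
    # lp.tokens == ["Hello", "!"]
    # lp.top_logprobs == [{"Hello": -0.1, "Hi": -2.4}, None]
    # lp.text_offset == [-1, -1]  (offsets are not supported yet)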
sglang/srt/eplb_simulator/__init__.py
ADDED
@@ -0,0 +1 @@
+from . import reader
sglang/srt/eplb_simulator/reader.py
ADDED
@@ -0,0 +1,51 @@
+from collections import defaultdict
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from sglang.srt.managers.expert_distribution import (
+    _convert_global_physical_count_to_logical_count,
+)
+
+convert_global_physical_count_to_logical_count = (
+    _convert_global_physical_count_to_logical_count
+)
+
+
+def read_mode_per_pass(dir_data: Path):
+    """Read data from ExpertDistributionRecorder when recorded with mode `per_pass`"""
+
+    # gpc := global_physical_count
+    gpc_of_forward_pass_and_rank = defaultdict(lambda: defaultdict())
+    for path in tqdm(list(dir_data.glob("*.pt"))):
+        data_pack = torch.load(path, weights_only=True)
+        last_physical_to_logical_map = data_pack["last_physical_to_logical_map"]
+        for record in data_pack["records"]:
+            forward_pass_id = record["forward_pass_id"]
+            rank = record["rank"]
+            assert (
+                gpc_of_forward_pass_and_rank[forward_pass_id].get(rank) is None
+            ), f"Duplicated {forward_pass_id=} {rank=}"
+            gpc_of_forward_pass_and_rank[forward_pass_id][rank] = record[
+                "global_physical_count"
+            ]
+
+    forward_pass_ids = sorted(gpc_of_forward_pass_and_rank.keys())
+    print(f"Make {forward_pass_ids=} into array")
+
+    items = []
+    for forward_pass_id, gpc_of_rank in sorted(gpc_of_forward_pass_and_rank.items()):
+        gpc_of_rank_tensor = torch.stack(
+            [gpc for rank, gpc in sorted(gpc_of_rank.items())]
+        ).sum(dim=0)
+        items.append(gpc_of_rank_tensor)
+
+    gpc_of_forward_pass = torch.stack(items)
+    print(f"{gpc_of_forward_pass.shape=}")
+
+    return dict(
+        global_physical_count_of_forward_pass=gpc_of_forward_pass,
+        last_physical_to_logical_map=last_physical_to_logical_map,
+        forward_pass_ids=forward_pass_ids,
+    )
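
Usage sketch (the dump directory is hypothetical): point the reader at a directory of .pt files produced by ExpertDistributionRecorder in per_pass mode and inspect the aggregated counts it returns.

    from pathlib import Path

    data = read_mode_per_pass(Path("/tmp/expert_distribution_dump"))
    gpc = data["global_physical_count_of_forward_pass"]  # per-pass counts, summed over ranks
    print(gpc.shape, data["forward_pass_ids"][:5])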
sglang/srt/function_call/base_format_detector.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
     ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
     _is_complete_json,
     _partial_json_loads,
 )
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -111,11 +111,10 @@ class BaseFormatDetector(ABC):
         # The current_text has tool_call if it is the start of a new tool call sequence
         # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
         if not (
-            self.
-            or current_text.startswith("{")
+            self.has_tool_call(current_text)
             or (
                 self.current_tool_id > 0
-                and current_text.startswith(self.tool_call_separator
+                and current_text.startswith(self.tool_call_separator)
             )
         ):
            # Only clear buffer if we're sure no tool call is starting
@@ -143,6 +142,10 @@ class BaseFormatDetector(ABC):
         try:
             if current_text.startswith(self.bot_token):
                 start_idx = len(self.bot_token)
+            elif self.current_tool_id > 0 and current_text.startswith(
+                self.tool_call_separator + self.bot_token
+            ):
+                start_idx = len(self.tool_call_separator + self.bot_token)
             elif self.current_tool_id > 0 and current_text.startswith(
                 self.tool_call_separator
             ):
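
The new elif branch covers a streamed buffer that begins with the separator immediately followed by the bot token; previously such a buffer matched the bare-separator branch and reached JSON parsing with the bot token still attached. A standalone illustration (the separator and bot token values here are hypothetical):

    bot_token = "<tool_call>"
    tool_call_separator = ","

    # Buffer as it may look while streaming a second tool call:
    current_text = ',<tool_call>{"name": "get_weather"}'

    if current_text.startswith(tool_call_separator + bot_token):
        start_idx = len(tool_call_separator + bot_token)

    print(current_text[start_idx:])  # {"name": "get_weather"}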
sglang/srt/function_call/deepseekv3_detector.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
 from sglang.srt.function_call.utils import _is_complete_json
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/ebnf_composer.py
CHANGED
@@ -211,20 +211,74 @@ class EBNFComposer:
             properties = params.get("properties", {})
             required_props = set(params.get("required", []))
 
-            #
-
+            # The generated pattern ensures:
+            # 1. Required properties appear first, joined by commas
+            # 2. Optional properties are wrapped with comma included: ( "," ( "prop" : value )? )?
+            # 3. For multiple optional properties, we allow flexible ordering:
+            #    - Each optional can be skipped entirely
+            #    - They can appear in any combination
+            #
+            # Example patterns generated:
+            # - One required, one optional:
+            #   "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
+            #   Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
+            #
+            # - Multiple optional properties with flexible ordering:
+            #   "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
+            #   Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
+            #           {"req": "x", "opt1": "y", "opt2": "z"}
+            #
+            # - All optional properties with flexible ordering:
+            #   "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
+            #   Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
+
+            prop_kv_pairs = {}
+            ordered_props = list(properties.keys())
+
             for prop_name, prop_schema in properties.items():
                 value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
                 # Create key=value pair
                 pair = key_value_template.format(key=prop_name, valrule=value_rule)
-
-
-
-
-
-
-                #
-
+                prop_kv_pairs[prop_name] = pair
+
+            # Separate into required and optional while preserving order
+            required = [p for p in ordered_props if p in required_props]
+            optional = [p for p in ordered_props if p not in required_props]
+
+            # Build the combined rule
+            rule_parts = []
+
+            # Add required properties joined by commas
+            if required:
+                rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
+
+            # Add optional properties with flexible ordering
+            if optional:
+                # Build alternatives where any optional property can appear first
+                opt_alternatives = []
+                for i in range(len(optional)):
+                    # Build pattern for optional[i] appearing first
+                    opt_parts = []
+                    for j in range(i, len(optional)):
+                        if j == i:
+                            opt_parts.append(prop_kv_pairs[optional[j]])
+                        else:
+                            opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
+                    opt_alternatives.append("".join(opt_parts))
+
+                # Wrap with appropriate comma handling based on whether we have required properties
+                if required:
+                    # Required properties exist, so optional group needs outer comma
+                    rule_parts.append(' ( "," ( ')
+                    rule_parts.append(" | ".join(opt_alternatives))
+                    rule_parts.append(" ) )?")
+                else:
+                    # All properties are optional
+                    rule_parts.append("( ")
+                    rule_parts.append(" | ".join(opt_alternatives))
+                    rule_parts.append(" )?")
+
+            combined_args = "".join(rule_parts)
             arguments_rule = args_template.format(arg_rules=combined_args)
 
             # Add the function call rule and its arguments rule
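
The comment block above describes the grammar shape; the nested loop can be exercised standalone to see the alternation it emits. A sketch with toy key/value pairs (property names invented):

    prop_kv_pairs = {p: f'"{p}" ":" value' for p in ("opt1", "opt2", "opt3")}
    optional = list(prop_kv_pairs)

    opt_alternatives = []
    for i in range(len(optional)):
        # optional[i] appears first, later optionals may follow in order
        opt_parts = [prop_kv_pairs[optional[i]]]
        for j in range(i + 1, len(optional)):
            opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
        opt_alternatives.append("".join(opt_parts))

    # With no required properties the rule becomes:
    print("( " + " | ".join(opt_alternatives) + " )?")
    # ( "opt1" ":" value ( "," "opt2" ":" value )? ( "," "opt3" ":" value )?
    #   | "opt2" ":" value ( "," "opt3" ":" value )? | "opt3" ":" value )?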
sglang/srt/function_call/function_call_parser.py
CHANGED
@@ -1,6 +1,12 @@
 import logging
 from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
 
+from sglang.srt.entrypoints.openai.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+    ToolChoice,
+)
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
-from sglang.srt.openai_api.protocol import (
-    StructuralTagResponseFormat,
-    StructuresResponseFormat,
-    Tool,
-    ToolChoice,
-)
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/llama32_detector.py
CHANGED
@@ -2,6 +2,7 @@ import json
 import logging
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/mistral_detector.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/pythonic_detector.py
CHANGED
@@ -4,6 +4,7 @@ import logging
 import re
 from typing import List, Optional
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/qwen25_detector.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/{openai_api/utils.py → jinja_template_utils.py}
RENAMED
@@ -1,11 +1,12 @@
-"""
-
+"""Template utilities for Jinja template processing.
+
+This module provides utilities for analyzing and processing Jinja chat templates,
+including content format detection and message processing.
 """
 
 import logging
-from typing import Dict, List
 
-import jinja2
+import jinja2
 import transformers.utils.chat_template_utils as hf_chat_utils
 
 logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
     return None
 
 
-def
+def detect_jinja_template_content_format(chat_template: str) -> str:
    """
    Detect whether a chat template expects 'string' or 'openai' content format.
 
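
A usage sketch for the renamed module (the template string is a toy example): a template that indexes message content directly should be detected as the 'string' format, while one that iterates over structured content parts should map to 'openai'.

    tmpl = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
    fmt = detect_jinja_template_content_format(tmpl)
    # Expected: "string", since the template treats content as a plain string.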
sglang/srt/layers/activation.py
CHANGED
@@ -20,6 +20,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import PretrainedConfig
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -28,9 +29,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_npu,
+    set_weight_attrs,
+)
+from sglang.utils import resolve_obj_by_qualname
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -50,6 +61,15 @@ class SiluAndMul(CustomOp):
         silu_and_mul(x, out)
         return out
 
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available:
+            d = x.shape[-1] // 2
+            output_shape = x.shape[:-1] + (d,)
+            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+            return out
+        else:
+            return self.forward_native(x)
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
@@ -165,8 +185,25 @@ def get_act_fn(
     return act_fn
 
 
-
+def get_cross_encoder_activation_function(config: PretrainedConfig):
+    if (
+        hasattr(config, "sbert_ce_default_activation_function")
+        and config.sbert_ce_default_activation_function is not None
+    ):
+
+        function_name = config.sbert_ce_default_activation_function
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons"
+        )
+        return resolve_obj_by_qualname(function_name)()
+    else:
+        # adapt bge-reranker
+        return nn.Identity()
+
+
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
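
A sketch of the new cross-encoder helper (the config here is constructed ad hoc): a qualified name under torch.nn.modules is resolved and instantiated, anything else is rejected by the assert, and configs without the attribute fall back to nn.Identity for models like bge-reranker.

    from transformers import PretrainedConfig

    cfg = PretrainedConfig()
    cfg.sbert_ce_default_activation_function = "torch.nn.modules.activation.Sigmoid"
    act = get_cross_encoder_activation_function(cfg)  # -> nn.Sigmoid()

    plain = get_cross_encoder_activation_function(PretrainedConfig())  # -> nn.Identity()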
sglang/srt/layers/attention/aiter_backend.py
CHANGED
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -717,6 +720,11 @@ class AiterIndicesUpdaterPrefill:
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.update = self.update_single_wrapper
 
+        # get the last index of the pool
+        self.pool_size = (
+            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
+        ) - 1
+
         self.kv_indices = None
         self.max_q_len = 0
         self.max_kv_len = 0
@@ -754,8 +762,16 @@ class AiterIndicesUpdaterPrefill:
             # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
-
-
+
+            # (TODO: Kk) WA - CI test_moe_eval_accuracy_large.py
+            # mha_batch_prefill reads 128 data to do computatoin
+            # if real data is not long enough then original padding value 0 is used
+            # but the 0 location will be made nan (noqa) in cuda graph capture mode
+            # this will cause the output tensor value becomes nan
+            # WA is to assure that last index of pool not changed
+            kv_indices = torch.full(
+                (paged_kernel_lens_sum + 128,),
+                self.pool_size,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 