sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -107,6 +107,8 @@ from sglang.version import __version__
|
|
107
107
|
logger = logging.getLogger(__name__)
|
108
108
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
109
109
|
|
110
|
+
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
111
|
+
|
110
112
|
|
111
113
|
# Store global states
|
112
114
|
@dataclasses.dataclass
|
@@ -212,9 +214,6 @@ async def validate_json_request(raw_request: Request):
|
|
212
214
|
)
|
213
215
|
|
214
216
|
|
215
|
-
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
216
|
-
|
217
|
-
|
218
217
|
##### Native API endpoints #####
|
219
218
|
|
220
219
|
|
@@ -807,6 +806,24 @@ async def retrieve_model(model: str):
|
|
807
806
|
)
|
808
807
|
|
809
808
|
|
809
|
+
@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
|
810
|
+
async def v1_score_request(request: ScoringRequest, raw_request: Request):
|
811
|
+
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
|
812
|
+
return await raw_request.app.state.openai_serving_score.handle_request(
|
813
|
+
request, raw_request
|
814
|
+
)
|
815
|
+
|
816
|
+
|
817
|
+
@app.api_route(
|
818
|
+
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
|
819
|
+
)
|
820
|
+
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
|
821
|
+
"""Endpoint for reranking documents based on query relevance."""
|
822
|
+
return await raw_request.app.state.openai_serving_rerank.handle_request(
|
823
|
+
request, raw_request
|
824
|
+
)
|
825
|
+
|
826
|
+
|
810
827
|
## SageMaker API
|
811
828
|
@app.get("/ping")
|
812
829
|
async def sagemaker_health() -> Response:
|
@@ -852,24 +869,6 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
|
|
852
869
|
return ORJSONResponse({"predictions": ret})
|
853
870
|
|
854
871
|
|
855
|
-
@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
|
856
|
-
async def v1_score_request(request: ScoringRequest, raw_request: Request):
|
857
|
-
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
|
858
|
-
return await raw_request.app.state.openai_serving_score.handle_request(
|
859
|
-
request, raw_request
|
860
|
-
)
|
861
|
-
|
862
|
-
|
863
|
-
@app.api_route(
|
864
|
-
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
|
865
|
-
)
|
866
|
-
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
|
867
|
-
"""Endpoint for reranking documents based on query relevance."""
|
868
|
-
return await raw_request.app.state.openai_serving_rerank.handle_request(
|
869
|
-
request, raw_request
|
870
|
-
)
|
871
|
-
|
872
|
-
|
873
872
|
def _create_error_response(e):
|
874
873
|
return ORJSONResponse(
|
875
874
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
@@ -916,15 +915,6 @@ def launch_server(
|
|
916
915
|
add_prometheus_middleware(app)
|
917
916
|
enable_func_timer()
|
918
917
|
|
919
|
-
image_token_text = None
|
920
|
-
if (
|
921
|
-
tokenizer_manager.image_token_id is not None
|
922
|
-
and not server_args.skip_tokenizer_init
|
923
|
-
):
|
924
|
-
image_token_text = tokenizer_manager.tokenizer.decode(
|
925
|
-
[tokenizer_manager.image_token_id]
|
926
|
-
)
|
927
|
-
|
928
918
|
# Send a warmup request - we will create the thread launch it
|
929
919
|
# in the lifespan after all other warmups have fired.
|
930
920
|
warmup_thread = threading.Thread(
|
@@ -932,7 +922,6 @@ def launch_server(
|
|
932
922
|
args=(
|
933
923
|
server_args,
|
934
924
|
pipe_finish_writer,
|
935
|
-
image_token_text,
|
936
925
|
launch_callback,
|
937
926
|
),
|
938
927
|
)
|
@@ -1066,7 +1055,6 @@ def _execute_server_warmup(
|
|
1066
1055
|
def _wait_and_warmup(
|
1067
1056
|
server_args: ServerArgs,
|
1068
1057
|
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
1069
|
-
image_token_text: str,
|
1070
1058
|
launch_callback: Optional[Callable[[], None]] = None,
|
1071
1059
|
):
|
1072
1060
|
if not server_args.skip_server_warmup:
|
@@ -192,9 +192,9 @@ class CompletionRequest(BaseModel):
|
|
192
192
|
session_params: Optional[Dict] = None
|
193
193
|
|
194
194
|
# For PD disaggregation
|
195
|
-
bootstrap_host: Optional[str] = None
|
196
|
-
bootstrap_port: Optional[int] = None
|
197
|
-
bootstrap_room: Optional[int] = None
|
195
|
+
bootstrap_host: Optional[Union[List[str], str]] = None
|
196
|
+
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
197
|
+
bootstrap_room: Optional[Union[List[int], int]] = None
|
198
198
|
|
199
199
|
# For request id
|
200
200
|
rid: Optional[Union[List[str], str]] = None
|
@@ -55,6 +55,20 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
55
55
|
def _request_id_prefix(self) -> str:
|
56
56
|
return "chatcmpl-"
|
57
57
|
|
58
|
+
def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
|
59
|
+
"""Validate that the input is valid."""
|
60
|
+
if not request.messages:
|
61
|
+
return "Messages cannot be empty."
|
62
|
+
|
63
|
+
if (
|
64
|
+
isinstance(request.tool_choice, str)
|
65
|
+
and request.tool_choice.lower() == "required"
|
66
|
+
and not request.tools
|
67
|
+
):
|
68
|
+
return "Tools cannot be empty if tool choice is set to required."
|
69
|
+
|
70
|
+
return None
|
71
|
+
|
58
72
|
def _convert_to_internal_request(
|
59
73
|
self,
|
60
74
|
request: ChatCompletionRequest,
|
@@ -113,12 +127,12 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
113
127
|
request.skip_special_tokens = False
|
114
128
|
if not isinstance(request.tool_choice, str):
|
115
129
|
tools = [
|
116
|
-
item.
|
130
|
+
item.model_dump()
|
117
131
|
for item in request.tools
|
118
132
|
if item.function.name == request.tool_choice.function.name
|
119
133
|
]
|
120
134
|
else:
|
121
|
-
tools = [item.
|
135
|
+
tools = [item.model_dump() for item in request.tools]
|
122
136
|
|
123
137
|
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
124
138
|
parser = FunctionCallParser(request.tools, tool_call_parser)
|
@@ -164,6 +178,25 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
164
178
|
audio_data,
|
165
179
|
modalities,
|
166
180
|
)
|
181
|
+
|
182
|
+
if "tool_calls" in processed_msg and isinstance(
|
183
|
+
processed_msg.get("tool_calls"), list
|
184
|
+
):
|
185
|
+
for call in processed_msg["tool_calls"]:
|
186
|
+
try:
|
187
|
+
if "arguments" in call["function"] and isinstance(
|
188
|
+
call["function"]["arguments"], str
|
189
|
+
):
|
190
|
+
call["function"]["arguments"] = json.loads(
|
191
|
+
call["function"]["arguments"]
|
192
|
+
)
|
193
|
+
except json.JSONDecodeError as e:
|
194
|
+
# Log a warning or error if JSON parsing fails for arguments
|
195
|
+
logger.warning(
|
196
|
+
f"Failed to parse tool call arguments as JSON: {e}"
|
197
|
+
)
|
198
|
+
# Decide whether to continue or raise the exception based on desired behavior
|
199
|
+
continue # Or raise e if strict parsing is required
|
167
200
|
openai_compatible_messages.append(processed_msg)
|
168
201
|
|
169
202
|
# Handle assistant prefix for continue_final_message
|
@@ -465,7 +498,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
465
498
|
|
466
499
|
# Handle tool calls
|
467
500
|
if request.tool_choice != "none" and request.tools:
|
468
|
-
async for chunk in self._process_tool_call_stream(
|
501
|
+
async for (
|
502
|
+
chunk,
|
503
|
+
tool_call_finish_reason_type,
|
504
|
+
) in self._process_tool_call_stream(
|
469
505
|
index,
|
470
506
|
delta,
|
471
507
|
parser_dict,
|
@@ -473,7 +509,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
473
509
|
request,
|
474
510
|
finish_reason_type,
|
475
511
|
):
|
476
|
-
yield chunk
|
512
|
+
if chunk:
|
513
|
+
yield chunk
|
514
|
+
finish_reason_type = tool_call_finish_reason_type
|
515
|
+
|
477
516
|
else:
|
478
517
|
# Regular content
|
479
518
|
if delta or not (
|
@@ -846,7 +885,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
846
885
|
choices=[choice_data],
|
847
886
|
model=request.model,
|
848
887
|
)
|
849
|
-
yield f"data: {chunk.model_dump_json()}\n\n"
|
888
|
+
yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
|
850
889
|
|
851
890
|
# Yield tool calls
|
852
891
|
for call_item in calls:
|
@@ -901,4 +940,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|
901
940
|
choices=[choice_data],
|
902
941
|
model=request.model,
|
903
942
|
)
|
904
|
-
yield f"data: {chunk.model_dump_json()}\n\n"
|
943
|
+
yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
|
944
|
+
|
945
|
+
if finish_reason_type == "stop":
|
946
|
+
yield None, "tool_calls"
|
@@ -66,7 +66,7 @@ def transform_select_experts_inputs(
|
|
66
66
|
info: Optional[ExpertLocationDispatchInfo],
|
67
67
|
):
|
68
68
|
if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
|
69
|
-
router_logits
|
69
|
+
router_logits.uniform_(5, 10)
|
70
70
|
if correction_bias is not None:
|
71
71
|
correction_bias = torch.zeros_like(correction_bias)
|
72
72
|
return router_logits, correction_bias
|
@@ -25,23 +25,49 @@ class BaseFormatDetector(ABC):
|
|
25
25
|
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
26
26
|
|
27
27
|
def __init__(self):
|
28
|
-
#
|
28
|
+
# Streaming state management
|
29
|
+
# Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
|
29
30
|
self._buffer = ""
|
30
|
-
#
|
31
|
+
# Stores complete tool call info (name and arguments) for each tool being parsed.
|
32
|
+
# Used by serving layer for completion handling when streaming ends.
|
33
|
+
# Format: [{"name": str, "arguments": dict}, ...]
|
31
34
|
self.prev_tool_call_arr: List[Dict] = []
|
35
|
+
# Index of currently streaming tool call. Starts at -1 (no active tool),
|
36
|
+
# increments as each tool completes. Tracks which tool's arguments are streaming.
|
32
37
|
self.current_tool_id: int = -1
|
38
|
+
# Flag for whether current tool's name has been sent to client.
|
39
|
+
# Tool names sent first with empty parameters, then arguments stream incrementally.
|
33
40
|
self.current_tool_name_sent: bool = False
|
34
|
-
|
35
|
-
|
36
|
-
|
41
|
+
# Tracks raw JSON string content streamed to client for each tool's arguments.
|
42
|
+
# Critical for serving layer to calculate remaining content when streaming ends.
|
43
|
+
# Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
|
44
|
+
self.streamed_args_for_tool: List[str] = []
|
45
|
+
|
46
|
+
# Token configuration (override in subclasses)
|
37
47
|
self.bot_token = ""
|
38
48
|
self.eot_token = ""
|
39
49
|
self.tool_call_separator = ", "
|
40
50
|
|
41
|
-
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
42
|
-
tool_indices = {
|
51
|
+
def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
|
52
|
+
"""
|
53
|
+
Get a mapping of tool names to their indices in the tools list.
|
54
|
+
|
55
|
+
This utility method creates a dictionary mapping function names to their
|
56
|
+
indices in the tools list, which is commonly needed for tool validation
|
57
|
+
and ToolCallItem creation.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
tools: List of available tools
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
Dictionary mapping tool names to their indices
|
64
|
+
"""
|
65
|
+
return {
|
43
66
|
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
|
44
67
|
}
|
68
|
+
|
69
|
+
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
70
|
+
tool_indices = self._get_tool_indices(tools)
|
45
71
|
if not isinstance(action, list):
|
46
72
|
action = [action]
|
47
73
|
|
@@ -130,11 +156,7 @@ class BaseFormatDetector(ABC):
|
|
130
156
|
|
131
157
|
# Build tool indices if not already built
|
132
158
|
if not hasattr(self, "_tool_indices"):
|
133
|
-
self._tool_indices = {
|
134
|
-
tool.function.name: i
|
135
|
-
for i, tool in enumerate(tools)
|
136
|
-
if tool.function and tool.function.name
|
137
|
-
}
|
159
|
+
self._tool_indices = self._get_tool_indices(tools)
|
138
160
|
|
139
161
|
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
140
162
|
|
@@ -294,12 +316,52 @@ class BaseFormatDetector(ABC):
|
|
294
316
|
|
295
317
|
@abstractmethod
|
296
318
|
def has_tool_call(self, text: str) -> bool:
|
319
|
+
"""
|
320
|
+
Check if the given text contains function call markers specific to this format.
|
321
|
+
"""
|
297
322
|
raise NotImplementedError()
|
298
323
|
|
324
|
+
def supports_structural_tag(self) -> bool:
|
325
|
+
"""Return True if this detector supports structural tag format."""
|
326
|
+
return True
|
327
|
+
|
299
328
|
@abstractmethod
|
300
329
|
def structure_info(self) -> _GetInfoFunc:
|
330
|
+
"""
|
331
|
+
Return a function that creates StructureInfo for constrained generation.
|
332
|
+
|
333
|
+
The returned function takes a tool name and returns a StructureInfo object
|
334
|
+
containing the begin/end patterns and trigger tokens needed for constrained
|
335
|
+
generation of function calls in this format.
|
336
|
+
|
337
|
+
Returns:
|
338
|
+
A function that takes a tool name (str) and returns StructureInfo
|
339
|
+
"""
|
301
340
|
raise NotImplementedError()
|
302
341
|
|
303
342
|
@abstractmethod
|
304
343
|
def build_ebnf(self, tools: List[Tool]) -> str:
|
344
|
+
"""
|
345
|
+
Build an EBNF grammar for constrained generation of function calls.
|
346
|
+
|
347
|
+
This method generates an Extended Backus-Naur Form (EBNF) grammar that
|
348
|
+
constrains the model's output to valid function calls in this format.
|
349
|
+
The grammar should include all available tools and their parameter schemas.
|
350
|
+
|
351
|
+
Args:
|
352
|
+
tools: List of available tools/functions that can be called
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
A string containing the EBNF grammar for this function call format
|
356
|
+
|
357
|
+
The EBNF grammar should:
|
358
|
+
- Define the overall structure of function calls in this format
|
359
|
+
- Include all tool names from the provided tools list
|
360
|
+
- Define valid JSON structures for function arguments
|
361
|
+
- Handle multiple function calls if the format supports them
|
362
|
+
|
363
|
+
Note:
|
364
|
+
Most implementations use EBNFComposer.build_ebnf() utility with
|
365
|
+
format-specific parameters rather than writing EBNF from scratch.
|
366
|
+
"""
|
305
367
|
raise NotImplementedError()
|
@@ -19,9 +19,28 @@ logger = logging.getLogger(__name__)
|
|
19
19
|
|
20
20
|
class DeepSeekV3Detector(BaseFormatDetector):
|
21
21
|
"""
|
22
|
-
Detector for DeepSeek
|
23
|
-
|
24
|
-
|
22
|
+
Detector for DeepSeek V3 model function call format.
|
23
|
+
|
24
|
+
The DeepSeek V3 format uses special Unicode tokens to delimit function calls
|
25
|
+
with JSON code blocks for arguments.
|
26
|
+
|
27
|
+
Format Structure:
|
28
|
+
```
|
29
|
+
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁calls▁end|><|end▁of▁sentence|>
|
30
|
+
```
|
31
|
+
Examples:
|
32
|
+
```
|
33
|
+
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
|
34
|
+
```
|
35
|
+
|
36
|
+
Key Components:
|
37
|
+
- Tool Calls Section: Wrapped between `<|tool▁calls▁begin|>` and `<|tool▁calls▁end|>`
|
38
|
+
- Individual Tool Call: Wrapped between `<|tool▁call▁begin|>` and `<|tool▁call▁end|>`
|
39
|
+
- Function Declaration: `function<|tool▁sep|>{function_name}`
|
40
|
+
- Arguments: JSON code block between ````json` and ````
|
41
|
+
- Supports multiple tool calls
|
42
|
+
|
43
|
+
Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
|
25
44
|
"""
|
26
45
|
|
27
46
|
def __init__(self):
|
@@ -89,16 +108,12 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
|
89
108
|
return StreamingParseResult(normal_text=new_text)
|
90
109
|
|
91
110
|
if not hasattr(self, "_tool_indices"):
|
92
|
-
self._tool_indices = {
|
93
|
-
tool.function.name: i
|
94
|
-
for i, tool in enumerate(tools)
|
95
|
-
if tool.function and tool.function.name
|
96
|
-
}
|
111
|
+
self._tool_indices = self._get_tool_indices(tools)
|
97
112
|
|
98
113
|
calls: list[ToolCallItem] = []
|
99
114
|
try:
|
100
115
|
partial_match = re.search(
|
101
|
-
pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)",
|
116
|
+
pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```.*",
|
102
117
|
string=current_text,
|
103
118
|
flags=re.DOTALL,
|
104
119
|
)
|
@@ -127,7 +142,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
|
127
142
|
)
|
128
143
|
)
|
129
144
|
self.current_tool_name_sent = True
|
130
|
-
# Store the tool call info for
|
145
|
+
# Store the tool call info for serving layer completions endpoint
|
131
146
|
self.prev_tool_call_arr[self.current_tool_id] = {
|
132
147
|
"name": func_name,
|
133
148
|
"arguments": {},
|
@@ -153,7 +168,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
|
|
153
168
|
] += argument_diff
|
154
169
|
|
155
170
|
if _is_complete_json(func_args_raw):
|
156
|
-
# Update the stored arguments
|
171
|
+
# Update the stored arguments
|
157
172
|
try:
|
158
173
|
parsed_args = json.loads(func_args_raw)
|
159
174
|
self.prev_tool_call_arr[self.current_tool_id][
|