sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/engine.py
CHANGED
@@ -27,12 +27,16 @@ import signal
 import threading
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 
+import zmq
+import zmq.asyncio
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import torch
 import uvloop
 
+from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
@@ -44,6 +48,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    RpcReqInput,
+    RpcReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -57,6 +63,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_zmq_socket,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -102,15 +109,25 @@ class Engine:
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)
 
+        # Allocate ports for inter-process communications
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
+
         # Launch subprocesses
         tokenizer_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args
+            server_args=server_args,
+            port_args=port_args,
         )
 
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
 
+        context = zmq.Context(2)
+        self.send_to_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, True
+        )
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -232,6 +249,13 @@ class Engine:
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.shutdown()
+        return False
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
@@ -296,7 +320,10 @@ class Engine:
         """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
         to avoid duplicated operations such as clearing cache."""
         obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
+            serialized_named_tensors=[
+                MultiprocessingSerializer.serialize(named_tensors)
+                for _ in range(self.server_args.tp_size)
+            ],
             load_format=load_format,
             flush_cache=flush_cache,
         )
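The request field changes from a single serialized blob to a list holding one serialized copy per tensor-parallel rank. A sketch of the new payload shape, with toy tensors standing in for real weights:

import torch

from sglang.srt.utils import MultiprocessingSerializer

named_tensors = [("lm_head.weight", torch.zeros(4, 4))]  # toy weights
tp_size = 2
serialized = [
    MultiprocessingSerializer.serialize(named_tensors) for _ in range(tp_size)
]
assert len(serialized) == tp_size  # one payload per TP rank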
@@ -350,6 +377,23 @@ class Engine:
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )
 
+    """
+    Execute an RPC call on all scheduler processes.
+    """
+
+    def collective_rpc(self, method: str, **kwargs):
+        obj = RpcReqInput(method=method, parameters=kwargs)
+        self.send_to_rpc.send_pyobj(obj)
+        recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
+        assert isinstance(recv_req, RpcReqOutput)
+        assert recv_req.success, recv_req.message
+
+    def save_remote_model(self, **kwargs):
+        self.collective_rpc("save_remote_model", **kwargs)
+
+    def save_sharded_model(self, **kwargs):
+        self.collective_rpc("save_sharded_model", **kwargs)
+
 
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
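collective_rpc pickles an RpcReqInput onto the DEALER socket created in __init__ and blocks until the scheduler acknowledges, so save_remote_model and save_sharded_model are thin wrappers over the same path. A hedged sketch; the url keyword is illustrative only and must match whatever the scheduler-side method actually accepts:

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")
engine.collective_rpc("save_remote_model", url="s3://my-bucket/ckpt")  # illustrative kwargs
engine.shutdown()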
@@ -408,7 +452,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -418,8 +464,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
     _set_envs_and_config(server_args)
 
     # Allocate ports for inter-process communications
-    port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+    logger.info(f"{server_args=}")
 
     # If using model from www.modelscope.cn, first download the model.
     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
@@ -502,6 +549,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
         tokenizer_manager, server_args.chat_template, server_args.model_path
     )
 
+    if server_args.completion_template:
+        load_completion_template_for_openai_api(server_args.completion_template)
+
     # Wait for the model to finish loading
     scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -14,11 +14,12 @@
 """
 The entry point of inference server. (SRT = SGLang Runtime)
 
-This file implements HTTP APIs for the
+This file implements HTTP APIs for the inference engine via fastapi.
 """
 
 import asyncio
 import dataclasses
+import json
 import logging
 import multiprocessing as multiprocessing
 import os
@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/generate_from_file", methods=["POST"])
+async def generate_from_file_request(file: UploadFile, request: Request):
+    """Handle a generate request, this is purely to work with input_embeds."""
+    content = await file.read()
+    input_embeds = json.loads(content.decode("utf-8"))
+
+    obj = GenerateReqInput(
+        input_embeds=input_embeds,
+        sampling_params={
+            "repetition_penalty": 1.2,
+            "temperature": 0.2,
+            "max_new_tokens": 512,
+        },
+    )
+
+    try:
+        ret = await _global_state.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        logger.error(f"Error: {e}")
+        return _create_error_response(e)
+
+
 @app.api_route("/encode", methods=["POST", "PUT"])
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
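A hedged client sketch for the new endpoint, assuming a server on localhost:30000. The uploaded JSON holds the raw input_embeds matrix; the endpoint fixes its own sampling parameters, so none are sent:

import json

import requests

# Toy (seq_len, hidden_size) embedding matrix serialized as nested lists.
with open("embeds.json", "w") as f:
    json.dump([[0.0] * 4096 for _ in range(8)], f)

with open("embeds.json", "rb") as f:
    resp = requests.post(
        "http://localhost:30000/generate_from_file",
        files={"file": f},  # field name matches the UploadFile parameter
    )
print(resp.json())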
@@ -283,7 +307,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
-@app.get("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
@@ -319,6 +343,36 @@ async def stop_profile_async():
     )
 
 
+@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
+async def start_expert_distribution_record_async():
+    """Start recording the expert distribution. Clear the previous record if any."""
+    await _global_state.tokenizer_manager.start_expert_distribution_record()
+    return Response(
+        content="Start recording the expert distribution.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
+async def stop_expert_distribution_record_async():
+    """Stop recording the expert distribution."""
+    await _global_state.tokenizer_manager.stop_expert_distribution_record()
+    return Response(
+        content="Stop recording the expert distribution.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
+async def dump_expert_distribution_record_async():
+    """Dump expert distribution record."""
+    await _global_state.tokenizer_manager.dump_expert_distribution_record()
+    return Response(
+        content="Dump expert distribution record.\n",
+        status_code=200,
+    )
+
+
 @app.post("/update_weights_from_disk")
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
     """Update the weights from disk inplace without re-launching the server."""
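All three record endpoints accept GET or POST. A typical record/dump cycle against a MoE model, assuming a server on localhost:30000:

import requests

base = "http://localhost:30000"  # assumes a running MoE server
requests.post(f"{base}/start_expert_distribution_record")
requests.post(
    f"{base}/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
)
requests.post(f"{base}/stop_expert_distribution_record")
requests.post(f"{base}/dump_expert_distribution_record")  # record is written server-side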
@@ -507,7 +561,13 @@ def available_models():
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+        model_cards.append(
+            ModelCard(
+                id=served_model_name,
+                root=served_model_name,
+                max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+            )
+        )
     return ModelList(data=model_cards)
@@ -706,9 +766,15 @@ def _wait_and_warmup(
         },
     }
     if server_args.skip_tokenizer_init:
-        json_data["input_ids"] = [10, 11, 12]
+        json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
+        # TODO Workaround the bug that embedding errors for list of size 1
+        if server_args.dp_size == 1:
+            json_data["input_ids"] = json_data["input_ids"][0]
     else:
-        json_data["text"] = "The capital city of France is"
+        json_data["text"] = ["The capital city of France is"] * server_args.dp_size
+        # TODO Workaround the bug that embedding errors for list of size 1
+        if server_args.dp_size == 1:
+            json_data["text"] = json_data["text"][0]
 
     # Debug dumping
     if server_args.debug_tensor_dump_input_file:
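The warmup request now batches one prompt per data-parallel rank, unwrapping to a scalar when dp_size == 1 to sidestep the single-element-batch bug noted in the TODOs. A self-contained sketch of the resulting payload shapes:

def warmup_text(dp_size: int):
    # Mirrors the new warmup logic: one prompt per DP rank,
    # unwrapped for dp_size == 1 (the workaround noted above).
    text = ["The capital city of France is"] * dp_size
    return text[0] if dp_size == 1 else text

assert warmup_text(1) == "The capital city of France is"
assert warmup_text(2) == ["The capital city of France is"] * 2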
@@ -719,14 +785,13 @@ def _wait_and_warmup(
         json_data["sampling_params"]["max_new_tokens"] = 0
 
     try:
-        for _ in range(server_args.dp_size):
-            res = requests.post(
-                url + request_name,
-                json=json_data,
-                headers=headers,
-                timeout=600,
-            )
-            assert res.status_code == 200, f"{res}"
+        res = requests.post(
+            url + request_name,
+            json=json_data,
+            headers=headers,
+            timeout=600,
+        )
+        assert res.status_code == 200, f"{res}"
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
sglang/srt/entrypoints/verl_engine.py
CHANGED
@@ -19,6 +19,7 @@ import torch.distributed as dist
 from torch.distributed.tensor import DeviceMesh, DTensor
 
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
 
@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()
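The patch is installed before any tensors cross process boundaries. A hedged sketch of the intended call order; the tensor here is a CPU toy, while the patch matters for CUDA tensors shipped between ranks:

import torch

from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.utils import MultiprocessingSerializer

monkey_patch_torch_reductions()  # install first, as VerlEngine.__init__ now does
payload = MultiprocessingSerializer.serialize([("w", torch.zeros(2, 2))])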
sglang/srt/function_call_parser.py
CHANGED
@@ -1,12 +1,21 @@
 import json
 import logging
 import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from json import JSONDecodeError, JSONDecoder
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
 
 import partial_json_parser
+from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
+
+from sglang.srt.openai_api.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+)
 
 logger = logging.getLogger(__name__)
@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
 ]
 
 
-class Function(BaseModel):
-    """Function Tool Template."""
-
-    description: Optional[str] = Field(default=None, examples=[None])
-    name: Optional[str] = None
-    parameters: Optional[object] = None
-
-
 class ToolCallItem(BaseModel):
     """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
 
@@ -74,7 +75,22 @@ class StreamingParseResult:
         self.calls = calls or []
 
 
-class BaseFormatDetector:
+@dataclass
+class StructureInfo:
+    begin: str
+    end: str
+    trigger: str
+
+
+_GetInfoFunc = Callable[[str], StructureInfo]
+"""
+helper alias of function
+usually it is a function that takes a name string and returns a StructureInfo object,
+which can be used to construct a structural_tag object
+"""
+
+
+class BaseFormatDetector(ABC):
     """Base class providing two sets of interfaces: one-time and streaming incremental."""
 
     def __init__(self):
@@ -90,26 +106,12 @@ class BaseFormatDetector:
         self.bot_token = ""
         self.eot_token = ""
 
-    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
         tool_indices = {
             tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
         }
         if not isinstance(action, list):
-            name = action.get("name")
-            if not name or name not in tool_indices:
-                logger.warning(f"Model attempted to call undefined function: {name}")
-                return []
-
-            return [
-                ToolCallItem(
-                    tool_index=tool_indices[name],
-                    name=name,
-                    parameters=json.dumps(
-                        action.get("parameters") or action.get("arguments", {}),
-                        ensure_ascii=False,
-                    ),
-                )
-            ]
+            action = [action]
 
         results = []
         for act in action:
@@ -125,19 +127,22 @@ class BaseFormatDetector:
                         ),
                     )
                 )
+            else:
+                logger.warning(f"Model attempted to call undefined function: {name}")
 
         return results
 
-    def detect_and_parse(self, text: str, tools: List[Function]):
+    @abstractmethod
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """
         Parses the text in one go. Returns success=True if the format matches, otherwise False.
         Note that leftover_text here represents "content that this parser will not consume further".
         """
         action = json.loads(text)
-        return self.parse_base_json(action, tools)
+        return StreamingParseResult(calls=self.parse_base_json(action, tools))
 
     def parse_streaming_increment(
-        self, new_text: str, tools: List[Function]
+        self, new_text: str, tools: List[Tool]
     ) -> StreamingParseResult:
         """
         Streaming incremental parsing with tool validation.
@@ -196,7 +201,7 @@ class BaseFormatDetector:
                     obj["arguments"] = obj["parameters"]
                 tool_call_arr.append(obj)
 
-            except partial_json_parser.core.exceptions.MalformedJSON:
+            except MalformedJSON:
                 return StreamingParseResult()
 
             if len(tool_call_arr) == 0:
@@ -285,7 +290,6 @@ class BaseFormatDetector:
                         calls=[
                             ToolCallItem(
                                 tool_index=self.current_tool_id,
-                                name="",
                                 parameters=argument_diff,
                             )
                         ],
@@ -302,6 +306,14 @@ class BaseFormatDetector:
             logger.error(f"Error in parse_streaming_increment: {e}")
             return StreamingParseResult()
 
+    @abstractmethod
+    def has_tool_call(self, text: str) -> bool:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+
 
 class Qwen25Detector(BaseFormatDetector):
     """
@@ -322,7 +334,7 @@ class Qwen25Detector(BaseFormatDetector):
         """Check if the text contains a Qwen 2.5 format tool call."""
         return self.bot_token in text
 
-    def detect_and_parse(self, text: str, tools: List[Function]):
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """
         One-time parsing: Detects and parses tool calls in the provided text.
 
@@ -330,15 +342,24 @@ class Qwen25Detector(BaseFormatDetector):
         :param tools: List of available tools.
         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
         """
-        if self.bot_token not in text:
-            return []
-
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
         pattern = rf"{self.bot_token}(.*?){self.eot_token}"
         match_result_list = re.findall(pattern, text, re.DOTALL)
         calls = []
         for match_result in match_result_list:
             match_result = json.loads(match_result)
             calls.extend(self.parse_base_json(match_result, tools))
-        return calls
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin='<tool_call>{"name":"' + name + '", "arguments":',
+            end="}</tool_call>",
+            trigger="<tool_call>",
+        )
 
 
 class MistralDetector(BaseFormatDetector):
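detect_and_parse now returns a StreamingParseResult carrying both the plain text before the tool call and the parsed calls, and structure_info exposes the begin/end/trigger markers used for constrained decoding. A hedged sketch; the Tool/Function field names and the <tool_call> markers are assumed from sglang.srt.openai_api.protocol and this version's detector:

from sglang.srt.function_call_parser import Qwen25Detector
from sglang.srt.openai_api.protocol import Function, Tool

tools = [Tool(type="function", function=Function(name="get_weather", parameters={}))]
det = Qwen25Detector()
res = det.detect_and_parse(
    'Sure.<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>',
    tools,
)
print(res.normal_text)          # "Sure."
print(res.calls[0].parameters)  # '{"city": "Paris"}'
print(det.structure_info()("get_weather").trigger)  # "<tool_call>"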
@@ -374,7 +395,7 @@ class MistralDetector(BaseFormatDetector):
         else:
             return ""
 
-    def detect_and_parse(self, text: str, tools: List[Function]):
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """
         One-time parsing: Detects and parses tool calls in the provided text.
 
@@ -382,6 +403,8 @@ class MistralDetector(BaseFormatDetector):
         :param tools: List of available tools.
         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
         """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
         text = self._clean_text(text)
         tool_content = text.replace("[TOOL_CALLS]", "").strip()
         raw_tool_calls = self.tool_call_regex.findall(tool_content)
@@ -391,7 +414,14 @@ class MistralDetector(BaseFormatDetector):
             function_call_arr = json.loads(raw_tool_call)
             for match_result in function_call_arr:
                 calls.extend(self.parse_base_json(match_result, tools))
-        return calls
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
+            end="}]",
+            trigger="[TOOL_CALLS]",
+        )
 
 
 class Llama32Detector(BaseFormatDetector):
@@ -411,19 +441,18 @@ class Llama32Detector(BaseFormatDetector):
         # prefix the output with the <|python_tag|> token
         return "<|python_tag|>" in text or text.startswith("{")
 
-    def detect_and_parse(self, text: str, tools: List[Function]):
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """Parse function calls from text, handling multiple JSON objects."""
         if "<|python_tag|>" not in text and not text.startswith("{"):
-            return []
+            return StreamingParseResult(normal_text=text, calls=[])
 
         if "<|python_tag|>" in text:
-            _, action_text = text.split("<|python_tag|>")
+            normal_text, action_text = text.split("<|python_tag|>")
         else:
-            action_text = text
+            normal_text, action_text = "", text
 
         # Split by semicolon and process each part
         json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
-
         all_actions = []
         for part in json_parts:
             try:
@@ -434,12 +463,18 @@ class Llama32Detector(BaseFormatDetector):
                 logger.warning(f"Failed to parse JSON part: {part}")
                 logger.warning(f"JSON parse error: {str(e)}")
                 continue
-
+        calls = []
         # Only process if we found valid JSON objects
         if all_actions:
-            return self.parse_base_json(all_actions, tools)
-
-        return []
+            calls = self.parse_base_json(all_actions, tools)
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
+            end="}",
+            trigger="<|python_tag|>",
+        )
 
 
 class MultiFormatParser:
|
@@ -449,7 +484,9 @@ class MultiFormatParser:
|
|
449
484
|
"""
|
450
485
|
self.detectors = detectors
|
451
486
|
|
452
|
-
def parse_once(
|
487
|
+
def parse_once(
|
488
|
+
self, text: str, tools: List[Tool]
|
489
|
+
) -> Tuple[str, list[ToolCallItem]]:
|
453
490
|
"""
|
454
491
|
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
455
492
|
Return: (final_text, all_calls)
|
@@ -459,15 +496,19 @@ class MultiFormatParser:
         final_calls = []
         final_normal_text = text
         for detector in self.detectors:
-            tool_call_list = detector.detect_and_parse(text, tools)
+            parsed_result = detector.detect_and_parse(text, tools)
+            tool_call_list = parsed_result.calls
             if len(tool_call_list) > 0:  # parsed successfully
                 final_calls = tool_call_list
+                final_normal_text = parsed_result.normal_text
                 break
 
         # leftover_text is the normal text not consumed by any Detector
         return final_normal_text, final_calls
 
-    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> Tuple[str, list[ToolCallItem]]:
         """
         Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
         and merge their produced normal_text/calls to return.
@@ -498,13 +539,13 @@ class FunctionCallParser:
     and returns the resulting normal_text and calls to the upper layer (or SSE).
     """
 
-    ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
+    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
         "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
     }
 
-    def __init__(self, tools: List[Function], tool_call_parser: str):
+    def __init__(self, tools: List[Tool], tool_call_parser: str):
         detectors = []
         if tool_call_parser:
             detector_class = self.ToolCallParserEnum.get(tool_call_parser)
@@ -532,7 +573,7 @@ class FunctionCallParser:
                 return True
         return False
 
-    def parse_non_stream(self, full_text: str):
+    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
         """
         Non-streaming call: one-time parsing
         """
@@ -541,7 +582,7 @@ class FunctionCallParser:
         )
         return full_normal_text, calls
 
-    def parse_stream_chunk(self, chunk_text: str):
+    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
         """
         Streaming call: incremental parsing
         """
@@ -549,3 +590,40 @@ class FunctionCallParser:
             chunk_text, self.tools
         )
         return normal_text, calls
+
+    def structure_infos(self) -> List[_GetInfoFunc]:
+        """
+        Returns a list of structure_info functions for each detector
+        """
+        return [
+            detector.structure_info() for detector in self.multi_format_parser.detectors
+        ]
+
+    def get_structure_tag(self) -> StructuralTagResponseFormat:
+        tool_structures: List[StructuresResponseFormat] = list()
+        tool_trigger_set: Set[str] = set()
+
+        for wrapper in self.structure_infos():
+            for tool in self.tools:
+                function = tool.function
+                name = function.name
+                assert name is not None
+                info = wrapper(name)
+
+                # accept all if not strict, otherwise only accept the schema
+                schema = function.parameters if function.strict else {}
+
+                tool_structures.append(
+                    StructuresResponseFormat(
+                        begin=info.begin,
+                        schema=schema,  # type: ignore
+                        end=info.end,
+                    )
+                )
+                tool_trigger_set.add(info.trigger)
+
+        return StructuralTagResponseFormat(
+            type="structural_tag",
+            structures=tool_structures,
+            triggers=list(tool_trigger_set),
+        )
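get_structure_tag folds every detector's StructureInfo over every declared tool into one structural_tag response format: non-strict functions get an accept-anything {} schema, strict ones keep their parameter schema. A hedged usage sketch with the same assumed protocol models as above:

from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.openai_api.protocol import Function, Tool

tools = [
    Tool(
        type="function",
        function=Function(name="get_weather", parameters={"type": "object"}),
    )
]
parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
tag = parser.get_structure_tag()
print(tag.triggers)  # ['<tool_call>']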