sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/xverse_moe.py
CHANGED
@@ -18,25 +18,25 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
+
+from sglang.srt.distributed import (
 get_tensor_model_parallel_rank,
 get_tensor_model_parallel_world_size,
 tensor_model_parallel_all_reduce,
 )
-from
-from
-from
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
 MergedColumnParallelLinear,
 QKVParallelLinear,
 ReplicatedLinear,
 RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,
sglang/srt/openai_api/adapter.py
CHANGED
@@ -20,7 +20,7 @@ import os
 import time
 import uuid
 from http import HTTPStatus
-from typing import Dict, List
+from typing import Dict, List, Optional

 from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
 generate_chat_conv,
 register_conv_template,
 )
+from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
 from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
 BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
 TopLogprob,
 UsageInfo,
 )
-from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
 ret,
 to_file=True,
 cache_report=tokenizer_manager.server_args.enable_cache_report,
+tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
 )
 else:
 responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
 tools = None
 if request.tools and request.tool_choice != "none":
 request.skip_special_tokens = False
-if request.stream:
-logger.warning("Streaming is not supported with tools.")
-request.stream = False
 if not isinstance(request.tool_choice, str):
 tools = [
 item.function.model_dump()
@@ -908,12 +906,26 @@
 openai_compatible_messages = openai_compatible_messages[:-1]
 else:
 assistant_prefix = None
-
-
-
-
-
-
+
+try:
+prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+openai_compatible_messages,
+tokenize=True,
+add_generation_prompt=True,
+tools=tools,
+)
+except:
+# This except branch will be triggered when the chosen model
+# has a different tools input format that is not compatiable
+# with openAI's apply_chat_template tool_call format, like Mistral.
+tools = [t if "function" in t else {"function": t} for t in tools]
+prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+openai_compatible_messages,
+tokenize=True,
+add_generation_prompt=True,
+tools=tools,
+)
+
 if assistant_prefix:
 prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
 stop = request.stop
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
 return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


-def v1_chat_generate_response(
+def v1_chat_generate_response(
+request, ret, to_file=False, cache_report=False, tool_call_parser=None
+):
 choices = []

 for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 if finish_reason == "stop":
 finish_reason = "tool_calls"
 try:
-
+parser = FunctionCallParser(tools, tool_call_parser)
+full_normal_text, call_info_list = parser.parse_non_stream(text)
 tool_calls = [
 ToolCall(
-id=str(call_info
+id=str(call_info.tool_index),
 function=FunctionResponse(
-name=call_info
+name=call_info.name, arguments=call_info.parameters
 ),
 )
 for call_info in call_info_list
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

 if adapted_request.stream:
+parser_dict = {}

 async def generate_stream_resp():
 is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 adapted_request, raw_request
 ):
 index = content.get("index", 0)
+text = content["text"]

 is_first = is_firsts.get(index, True)
 stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

 text = content["text"]
 delta = text[len(stream_buffer) :]
-
-choice_data = ChatCompletionResponseStreamChoice(
-index=index,
-delta=DeltaMessage(content=delta),
-finish_reason=(finish_reason["type"] if finish_reason else ""),
-matched_stop=(
-finish_reason["matched"]
-if finish_reason and "matched" in finish_reason
-else None
-),
-logprobs=choice_logprobs,
-)
-chunk = ChatCompletionStreamResponse(
-id=content["meta_info"]["id"],
-choices=[choice_data],
-model=request.model,
-)
+new_stream_buffer = stream_buffer + delta

-
-
-
+if request.tool_choice != "none" and request.tools:
+if index not in parser_dict:
+parser_dict[index] = FunctionCallParser(
+tools=request.tools,
+tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+)
+parser = parser_dict[index]
+
+# parse_increment => returns (normal_text, calls)
+normal_text, calls = parser.parse_stream_chunk(delta)
+
+# 1) if there's normal_text, output it as normal content
+if normal_text:
+choice_data = ChatCompletionResponseStreamChoice(
+index=index,
+delta=DeltaMessage(content=normal_text),
+finish_reason=(
+finish_reason["type"] if finish_reason else ""
+),
+)
+chunk = ChatCompletionStreamResponse(
+id=content["meta_info"]["id"],
+choices=[choice_data],
+model=request.model,
+)
+yield f"data: {chunk.model_dump_json()}\n\n"
+
+# 2) if we found calls, we output them as separate chunk(s)
+for call_item in calls:
+# transform call_item -> FunctionResponse + ToolCall
+
+if (
+content["meta_info"]["finish_reason"]
+and content["meta_info"]["finish_reason"]["type"]
+== "stop"
+):
+latest_delta_len = 0
+if isinstance(call_item.parameters, str):
+latest_delta_len = len(call_item.parameters)
+
+expected_call = json.dumps(
+parser.multi_format_parser.detectors[0]
+.prev_tool_call_arr[index]
+.get("arguments", {}),
+ensure_ascii=False,
+)
+actual_call = parser.multi_format_parser.detectors[
+0
+].streamed_args_for_tool[index]
+if latest_delta_len > 0:
+actual_call = actual_call[:-latest_delta_len]
+remaining_call = expected_call.replace(
+actual_call, "", 1
+)
+call_item.parameters = remaining_call
+
+tool_call = ToolCall(
+id=str(call_item.tool_index),
+function=FunctionResponse(
+name=call_item.name,
+arguments=call_item.parameters,
+),
+)
+choice_data = ChatCompletionResponseStreamChoice(
+index=index,
+delta=DeltaMessage(
+role="assistant", tool_calls=[tool_call]
+),
+finish_reason="tool_call",
+)
+chunk = ChatCompletionStreamResponse(
+id=content["meta_info"]["id"],
+choices=[choice_data],
+model=request.model,
+)
+yield f"data: {chunk.model_dump_json()}\n\n"

-
+stream_buffers[index] = new_stream_buffer
+is_firsts[index] = is_first
+
+else:
+# No tool calls => just treat this as normal text
+choice_data = ChatCompletionResponseStreamChoice(
+index=index,
+delta=DeltaMessage(content=delta),
+finish_reason=(
+finish_reason["type"] if finish_reason else ""
+),
+matched_stop=(
+finish_reason["matched"]
+if finish_reason and "matched" in finish_reason
+else None
+),
+logprobs=choice_logprobs,
+)
+chunk = ChatCompletionStreamResponse(
+id=content["meta_info"]["id"],
+choices=[choice_data],
+model=request.model,
+)
+yield f"data: {chunk.model_dump_json()}\n\n"
+stream_buffers[index] = new_stream_buffer
+is_firsts[index] = is_first
 if request.stream_options and request.stream_options.include_usage:
 total_prompt_tokens = sum(
 tokens
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 ret = [ret]

 response = v1_chat_generate_response(
-request,
+request,
+ret,
+cache_report=tokenizer_manager.server_args.enable_cache_report,
+tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
 )

 return response
sglang/srt/openai_api/protocol.py
CHANGED
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
 ignore_eos: bool = False
 skip_special_tokens: bool = True
 lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+session_params: Optional[Dict] = None


 class CompletionResponseChoice(BaseModel):
@@ -261,7 +262,7 @@ class Function(BaseModel):
 """Function descriptions."""

 description: Optional[str] = Field(default=None, examples=[None])
-name: str
+name: Optional[str] = None
 parameters: Optional[object] = None


@@ -275,7 +276,7 @@ class Tool(BaseModel):
 class ToolChoiceFuncName(BaseModel):
 """The name of tool choice function."""

-name: str
+name: Optional[str] = None


 class ToolChoice(BaseModel):
@@ -322,13 +323,14 @@ class ChatCompletionRequest(BaseModel):
 ignore_eos: bool = False
 skip_special_tokens: bool = True
 lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+session_params: Optional[Dict] = None


 class FunctionResponse(BaseModel):
 """Function response."""

-name: str
-arguments: str
+name: Optional[str] = None
+arguments: Optional[str] = None


 class ToolCall(BaseModel):
@@ -365,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
 class DeltaMessage(BaseModel):
 role: Optional[str] = None
 content: Optional[str] = None
+tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionResponseStreamChoice(BaseModel):
sglang/srt/sampling/custom_logit_processor.py
ADDED
@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+"""Deserialize a json string to a Callable object.
+This function is cached to avoid redundant deserialization.
+"""
+data = json.loads(json_str)
+return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+"""Abstract base class for callable functions."""
+
+@abstractmethod
+def __call__(
+self,
+logits: torch.Tensor,
+custom_param_list: Optional[List[Dict[str, Any]]] = None,
+) -> torch.Tensor:
+"""Define the callable behavior."""
+raise NotImplementedError
+
+def to_str(self) -> str:
+"""Serialize the callable function to a JSON-compatible string."""
+return json.dumps({"callable": dill.dumps(self).hex()})
+
+@classmethod
+def from_str(cls, json_str: str):
+"""Deserialize a callable function from a JSON string."""
+return _cache_from_str(json_str)
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
CHANGED
@@ -3,11 +3,16 @@ from typing import List
 import torch

 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import
+from sglang.srt.utils import get_compiler_backend

-
-
-
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+logits[:] = torch.where(
+logits > 0,
+logits / scaling_penalties,
+logits * scaling_penalties,
+)


 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,8 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
 self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

 def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-
-
-logits, self.cumulated_repetition_penalties
-)
-else:
-return torch.where(
-logits > 0,
-logits / self.cumulated_repetition_penalties,
-logits * self.cumulated_repetition_penalties,
-)
+apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
+return logits

 def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
 self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -3,17 +3,15 @@ from __future__ import annotations
 import dataclasses
 import logging
 import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import torch

-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+apply_scaling_penalties,
+)

 logger = logging.getLogger(__name__)

@@ -36,6 +34,9 @@ class SamplingBatchInfo:
 # Dispatch in CUDA graph
 need_min_p_sampling: bool

+# Whether any request has custom logit processor
+has_custom_logit_processor: bool
+
 # Bias Tensors
 vocab_size: int
 grammars: Optional[List] = None
@@ -52,6 +53,14 @@ class SamplingBatchInfo:
 # Device
 device: str = "cuda"

+# Custom Parameters
+custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+# Custom Logit Processor
+custom_logit_processor: Optional[
+Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+] = None
+
 @classmethod
 def from_schedule_batch(
 cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +85,39 @@ class SamplingBatchInfo:
 [r.sampling_params.min_p for r in reqs], dtype=torch.float
 ).to(device, non_blocking=True)

+# Check if any request has custom logit processor
+has_custom_logit_processor = (
+batch.enable_custom_logit_processor  # check the flag first.
+and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+)
+
+if has_custom_logit_processor:
+# Merge the same type of custom logit processors together
+processor_dict = {}
+for i, r in enumerate(reqs):
+if r.custom_logit_processor is None:
+continue
+processor_str = r.custom_logit_processor
+if processor_str not in processor_dict:
+processor_dict[processor_str] = []
+processor_dict[processor_str].append(i)
+
+merged_custom_logit_processor = {
+hash(processor_str): (
+# The deserialized custom logit processor object
+CustomLogitProcessor.from_str(processor_str),
+# The mask tensor for the requests that use this custom logit processor
+torch.zeros(len(reqs), dtype=torch.bool)
+.scatter_(0, torch.tensor(true_indices), True)
+.to(device, non_blocking=True),
+)
+for processor_str, true_indices in processor_dict.items()
+}
+custom_params = [r.sampling_params.custom_params for r in reqs]
+else:
+merged_custom_logit_processor = None
+custom_params = None
+
 ret = cls(
 temperatures=temperatures,
 top_ps=top_ps,
@@ -83,8 +125,11 @@ class SamplingBatchInfo:
 min_ps=min_ps,
 need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
 is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+has_custom_logit_processor=has_custom_logit_processor,
 vocab_size=vocab_size,
 device=device,
+custom_params=custom_params,
+custom_logit_processor=merged_custom_logit_processor,
 )
 # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -184,6 +229,8 @@ class SamplingBatchInfo:

 def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
 self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+if self.has_custom_logit_processor:
+self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

 for item in [
 "temperatures",
@@ -196,6 +243,27 @@ class SamplingBatchInfo:
 if value is not None: # logit_bias can be None
 setattr(self, item, value[new_indices])

+def _filter_batch_custom_logit_processor(
+self, unfinished_indices: List[int], new_indices: torch.Tensor
+):
+"""Filter the custom logit processor and custom params"""
+
+self.custom_logit_processor = {
+k: (p, mask[new_indices])
+for k, (p, mask) in self.custom_logit_processor.items()
+if any(
+mask[new_indices]
+)  # ignore the custom logit processor whose mask is all False
+}
+self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+# If the custom logit processor is an empty dict, set the flag to False,
+# and set the custom logit processor and custom params to None.
+if len(self.custom_logit_processor) == 0:
+self.custom_logit_processor = None
+self.custom_params = None
+self.has_custom_logit_processor = False
+
 @staticmethod
 def merge_bias_tensor(
 lhs: torch.Tensor,
@@ -221,9 +289,76 @@ class SamplingBatchInfo:

 return None

+@staticmethod
+def merge_custom_logit_processor(
+lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+bs1: int,
+bs2: int,
+device: str,
+):
+if lhs is None and rhs is None:
+return None
+lhs, rhs = lhs or {}, rhs or {}
+
+keys = set(lhs.keys()).union(set(rhs.keys()))
+merged_dict = {}
+
+for k in keys:
+# Get the logit processor object
+processor = lhs[k][0] if k in lhs else rhs[k][0]
+# Get and merge the mask tensors from the two dicts
+left_mask = (
+lhs[k][1]
+if k in lhs
+else torch.zeros(bs1, dtype=torch.bool, device=device)
+)
+right_mask = (
+rhs[k][1]
+if k in rhs
+else torch.zeros(bs2, dtype=torch.bool, device=device)
+)
+merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+f"\n{lhs=}\n{rhs=}"
+)
+
+return merged_dict
+
 def merge_batch(self, other: "SamplingBatchInfo"):
 self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

+# Merge the logit bias tensor
+self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+self.logit_bias, other.logit_bias, len(self), len(other), self.device
+)
+# Merge the custom logit processors and custom params lists
+if self.has_custom_logit_processor or other.has_custom_logit_processor:
+# Merge the custom logit processors
+self.custom_logit_processor = (
+SamplingBatchInfo.merge_custom_logit_processor(
+self.custom_logit_processor,
+other.custom_logit_processor,
+len(self),
+len(other),
+self.device,
+)
+)
+# Merge the custom params lists
+self.custom_params = self.custom_params or [None] * len(self)
+other.custom_params = other.custom_params or [None] * len(other)
+self.custom_params.extend(other.custom_params)
+
+# Set the flag to True if any of the two has custom logit processor
+self.has_custom_logit_processor = True
+
+# Note: becasue the __len()__ operator is defined on the temperatures tensor,
+# please make sure any merge operation with len(self) or len(other) is done before
+# the merge operation of the temperatures tensor below.
 for item in [
 "temperatures",
 "top_ps",
@@ -235,9 +370,6 @@ class SamplingBatchInfo:
 setattr(self, item, torch.concat([self_val, other_val]))

 self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-self.logit_bias, other.logit_bias, len(self), len(other), self.device
-)
 self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

 def apply_logits_bias(self, logits: torch.Tensor):
@@ -251,14 +383,7 @@ class SamplingBatchInfo:

 # repetition
 if self.scaling_penalties is not None:
-
-logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-else:
-logits[:] = torch.where(
-logits > 0,
-logits / self.scaling_penalties,
-logits * self.scaling_penalties,
-)
+apply_scaling_penalties(logits, self.scaling_penalties)

 # Apply regex vocab_mask
 if self.vocab_mask is not None:
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Sampling parameters for text generation."""

-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6

@@ -48,6 +48,7 @@ class SamplingParams:
 no_stop_trim: bool = False,
 ignore_eos: bool = False,
 skip_special_tokens: bool = True,
+custom_params: Optional[Dict[str, Any]] = None,
 ) -> None:
 self.temperature = temperature
 self.top_p = top_p
@@ -71,6 +72,7 @@ class SamplingParams:
 self.json_schema = json_schema
 self.ebnf = ebnf
 self.no_stop_trim = no_stop_trim
+self.custom_params = custom_params

 # Process some special cases
 if self.temperature < _SAMPLING_EPS: