sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +4 -1
- sglang/srt/layers/rotary_embedding.py +6 -1
- sglang/srt/layers/sampler.py +28 -8
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +61 -35
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -65
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -20,7 +20,7 @@ import os
 import time
 import uuid
 from http import HTTPStatus
-from typing import Dict, List
+from typing import Dict, List, Optional

 from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
+from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
 from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
-from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 ret,
                 to_file=True,
                 cache_report=tokenizer_manager.server_args.enable_cache_report,
+                tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
             )
         else:
             responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
         tools = None
         if request.tools and request.tool_choice != "none":
             request.skip_special_tokens = False
-            if request.stream:
-                logger.warning("Streaming is not supported with tools.")
-                request.stream = False
             if not isinstance(request.tool_choice, str):
                 tools = [
                     item.function.model_dump()
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
                 openai_compatible_messages = openai_compatible_messages[:-1]
             else:
                 assistant_prefix = None
-
-
-
-
-
-
+
+            try:
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    tools=tools,
+                )
+            except:
+                # This except branch will be triggered when the chosen model
+                # has a different tools input format that is not compatiable
+                # with openAI's apply_chat_template tool_call format, like Mistral.
+                tools = [t if "function" in t else {"function": t} for t in tools]
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    tools=tools,
+                )
+
             if assistant_prefix:
                 prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
             stop = request.stop
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


-def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
+def v1_chat_generate_response(
+    request, ret, to_file=False, cache_report=False, tool_call_parser=None
+):
     choices = []

     for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
             if finish_reason == "stop":
                 finish_reason = "tool_calls"
             try:
-
+                parser = FunctionCallParser(tools, tool_call_parser)
+                full_normal_text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = [
                     ToolCall(
-                        id=str(call_info
+                        id=str(call_info.tool_index),
                         function=FunctionResponse(
-                            name=call_info
+                            name=call_info.name, arguments=call_info.parameters
                         ),
                     )
                     for call_info in call_info_list
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

     if adapted_request.stream:
+        parser_dict = {}

         async def generate_stream_resp():
             is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
+                    text = content["text"]

                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

                     text = content["text"]
                     delta = text[len(stream_buffer) :]
-
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=index,
-                        delta=DeltaMessage(content=delta),
-                        finish_reason=(finish_reason["type"] if finish_reason else ""),
-                        matched_stop=(
-                            finish_reason["matched"]
-                            if finish_reason and "matched" in finish_reason
-                            else None
-                        ),
-                        logprobs=choice_logprobs,
-                    )
-                    chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        choices=[choice_data],
-                        model=request.model,
-                    )
+                    new_stream_buffer = stream_buffer + delta

-
-
-
+                    if request.tool_choice != "none" and request.tools:
+                        if index not in parser_dict:
+                            parser_dict[index] = FunctionCallParser(
+                                tools=request.tools,
+                                tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                            )
+                        parser = parser_dict[index]
+
+                        # parse_increment => returns (normal_text, calls)
+                        normal_text, calls = parser.parse_stream_chunk(delta)
+
+                        # 1) if there's normal_text, output it as normal content
+                        if normal_text:
+                            choice_data = ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(content=normal_text),
+                                finish_reason=(
+                                    finish_reason["type"] if finish_reason else ""
+                                ),
+                            )
+                            chunk = ChatCompletionStreamResponse(
+                                id=content["meta_info"]["id"],
+                                choices=[choice_data],
+                                model=request.model,
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"
+
+                        # 2) if we found calls, we output them as separate chunk(s)
+                        for call_item in calls:
+                            # transform call_item -> FunctionResponse + ToolCall
+
+                            if (
+                                content["meta_info"]["finish_reason"]
+                                and content["meta_info"]["finish_reason"]["type"]
+                                == "stop"
+                            ):
+                                latest_delta_len = 0
+                                if isinstance(call_item.parameters, str):
+                                    latest_delta_len = len(call_item.parameters)
+
+                                expected_call = json.dumps(
+                                    parser.multi_format_parser.detectors[0]
+                                    .prev_tool_call_arr[index]
+                                    .get("arguments", {}),
+                                    ensure_ascii=False,
+                                )
+                                actual_call = parser.multi_format_parser.detectors[
+                                    0
+                                ].streamed_args_for_tool[index]
+                                if latest_delta_len > 0:
+                                    actual_call = actual_call[:-latest_delta_len]
+                                remaining_call = expected_call.replace(
+                                    actual_call, "", 1
+                                )
+                                call_item.parameters = remaining_call
+
+                            tool_call = ToolCall(
+                                id=str(call_item.tool_index),
+                                function=FunctionResponse(
+                                    name=call_item.name,
+                                    arguments=call_item.parameters,
+                                ),
+                            )
+                            choice_data = ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(
+                                    role="assistant", tool_calls=[tool_call]
+                                ),
+                                finish_reason="tool_call",
+                            )
+                            chunk = ChatCompletionStreamResponse(
+                                id=content["meta_info"]["id"],
+                                choices=[choice_data],
+                                model=request.model,
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"

-
+                        stream_buffers[index] = new_stream_buffer
+                        is_firsts[index] = is_first
+
+                    else:
+                        # No tool calls => just treat this as normal text
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(content=delta),
+                            finish_reason=(
+                                finish_reason["type"] if finish_reason else ""
+                            ),
+                            matched_stop=(
+                                finish_reason["matched"]
+                                if finish_reason and "matched" in finish_reason
+                                else None
+                            ),
+                            logprobs=choice_logprobs,
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+                        stream_buffers[index] = new_stream_buffer
+                        is_firsts[index] = is_first
                 if request.stream_options and request.stream_options.include_usage:
                     total_prompt_tokens = sum(
                         tokens
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         ret = [ret]

     response = v1_chat_generate_response(
-        request,
+        request,
+        ret,
+        cache_report=tokenizer_manager.server_args.enable_cache_report,
+        tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
     )

     return response
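The non-streaming branch of `v1_chat_generate_response` above now delegates tool-call extraction to the new `FunctionCallParser`. A minimal sketch of that flow, assuming only the interface visible in this diff (`parse_non_stream` returns the remaining text plus call items carrying `tool_index`, `name`, and `parameters`); the helper name `extract_tool_calls` is illustrative, not part of the package:

```python
# Sketch only: mirrors the non-streaming path in v1_chat_generate_response above.
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.openai_api.protocol import FunctionResponse, ToolCall


def extract_tool_calls(text, tools, tool_call_parser):
    # Split generated text into normal content and structured tool calls.
    parser = FunctionCallParser(tools, tool_call_parser)
    normal_text, call_info_list = parser.parse_non_stream(text)
    tool_calls = [
        ToolCall(
            id=str(call_info.tool_index),
            function=FunctionResponse(
                name=call_info.name, arguments=call_info.parameters
            ),
        )
        for call_info in call_info_list
    ]
    return normal_text, tool_calls
```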
sglang/srt/openai_api/protocol.py
CHANGED
@@ -262,7 +262,7 @@ class Function(BaseModel):
     """Function descriptions."""

     description: Optional[str] = Field(default=None, examples=[None])
-    name: str
+    name: Optional[str] = None
     parameters: Optional[object] = None


@@ -276,7 +276,7 @@ class Tool(BaseModel):
 class ToolChoiceFuncName(BaseModel):
     """The name of tool choice function."""

-    name: str
+    name: Optional[str] = None


 class ToolChoice(BaseModel):
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
 class FunctionResponse(BaseModel):
     """Function response."""

-    name: str
-    arguments: str
+    name: Optional[str] = None
+    arguments: Optional[str] = None


 class ToolCall(BaseModel):
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionResponseStreamChoice(BaseModel):
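With `DeltaMessage` gaining a `tool_calls` field and `name`/`arguments` becoming optional, a streamed chunk can now carry a partial tool call, e.g. argument text without a function name. A small sketch of that shape, assuming `ToolCall` needs only the `id` and `function` fields used by the adapter above; the values are made up:

```python
# Sketch: a streaming delta carrying a partial tool call (illustrative values).
from sglang.srt.openai_api.protocol import DeltaMessage, FunctionResponse, ToolCall

delta = DeltaMessage(
    role="assistant",
    tool_calls=[
        ToolCall(
            id="0",
            # Later chunks may omit `name` and carry only new argument text.
            function=FunctionResponse(arguments='{"city": "Paris"}'),
        )
    ],
)
print(delta.model_dump_json())
```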
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
CHANGED
@@ -3,11 +3,16 @@ from typing import List
 import torch

 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import
+from sglang.srt.utils import get_compiler_backend

-
-
-
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+    logits[:] = torch.where(
+        logits > 0,
+        logits / scaling_penalties,
+        logits * scaling_penalties,
+    )


 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,8 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-
-
-                logits, self.cumulated_repetition_penalties
-            )
-        else:
-            return torch.where(
-                logits > 0,
-                logits / self.cumulated_repetition_penalties,
-                logits * self.cumulated_repetition_penalties,
-            )
+        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
+        return logits

     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
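The new `apply_scaling_penalties` helper replaces both the previous `sgl_kernel` fast path and the inline `torch.where` fallback with one compiled function: logits above zero are divided by the penalty and the rest are multiplied by it, so a penalty greater than one always makes repeated tokens less likely. A standalone toy version of the same element-wise rule (without the `torch.compile` decorator or sglang imports):

```python
import torch


def apply_scaling_penalties(logits, scaling_penalties):
    # Same element-wise rule as the compiled helper added above.
    logits[:] = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )


logits = torch.tensor([[2.0, -2.0, 0.5]])
penalties = torch.full_like(logits, 2.0)  # repetition penalty > 1 discourages tokens
apply_scaling_penalties(logits, penalties)
print(logits)  # tensor([[ 1.0000, -4.0000,  0.2500]])
```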
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import torch

-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)

 logger = logging.getLogger(__name__)

@@ -386,14 +383,7 @@ class SamplingBatchInfo:

         # repetition
         if self.scaling_penalties is not None:
-
-                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)

         # Apply regex vocab_mask
         if self.vocab_mask is not None:
sglang/srt/server.py
CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================

-# Some shortcuts for backward
+# Some shortcuts for backward compatibility.
 # They will be removed in new versions.
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
sglang/srt/server_args.py
CHANGED
@@ -75,6 +75,7 @@ class ServerArgs:
     # Other runtime options
     tp_size: int = 1
     stream_interval: int = 1
+    stream_output: bool = False
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
     watchdog_timeout: float = 300
@@ -161,6 +162,7 @@ class ServerArgs:

     # Custom logit processor
     enable_custom_logit_processor: bool = False
+    tool_call_parser: str = None

     def __post_init__(self):
         # Set missing default values
@@ -317,6 +319,7 @@ class ServerArgs:
                 "dummy",
                 "gguf",
                 "bitsandbytes",
+                "layered",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
@@ -330,7 +333,10 @@ class ServerArgs:
             "which is mainly for profiling."
             '"gguf" will load the weights in the gguf format. '
             '"bitsandbytes" will load the weights using bitsandbytes '
-            "quantization."
+            "quantization."
+            '"layered" loads weights layer by layer so that one can quantize a '
+            "layer before loading another to make the peak memory envelope "
+            "smaller.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -495,6 +501,11 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--stream-output",
+            action="store_true",
+            help="Whether to output as a sequence of disjoint segments.",
+        )
         parser.add_argument(
             "--random-seed",
             type=int,
@@ -873,6 +884,14 @@ class ServerArgs:
             action="store_true",
             help="Enable users to pass custom logit processors to the server (disabled by default for security)",
         )
+        # Function Calling
+        parser.add_argument(
+            "--tool-call-parser",
+            type=str,
+            choices=["qwen25", "mistral", "llama3"],
+            default=ServerArgs.tool_call_parser,
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
     def __init__(self):
         self.prev_mode = ForwardMode.DECODE
-        self.sample_output = None

         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0

+        # shape: (b, hidden_size)
         self.hidden_states: torch.Tensor = None
+        # shape: (b,)
         self.verified_id: torch.Tensor = None
+        # shape: (b, vocab_size)
+        self.sample_output: torch.Tensor = None
+
         self.positions: torch.Tensor = None
         self.accept_length: torch.Tensor = None
-        self.
-        self.unfinished_index: List[int] = None
+        self.accept_length_cpu: List[int] = None

     def load_server_args(self, server_args: ServerArgs):
         self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
                 :pre_len
             ] = req.prefix_indices

-            batch.req_to_token_pool.req_to_token[req.req_pool_idx
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )

@@ -228,6 +231,14 @@ class EAGLEDraftInput(SpecInfo):
         assert len(batch.extend_lens) == 1
         batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

+    def filter_batch(
+        self,
+        new_indices: torch.Tensor,
+    ):
+        self.sample_output = self.sample_output[: len(new_indices)]
+        self.hidden_states = self.hidden_states[: len(new_indices)]
+        self.verified_id = self.verified_id[: len(new_indices)]
+
     def prepare_for_decode(self, batch: ScheduleBatch):
         prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
@@ -287,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
-            + torch.
+            + torch.full(
+                [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+            )
         ).flatten()

         bs = len(batch.seq_lens)
@@ -304,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):

     def prepare_extend_after_decode(self, batch: ScheduleBatch):
         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
-
+        accept_length_cpu = batch.spec_info.accept_length_cpu
+        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        seq_lens_cpu = batch.seq_lens.tolist()

         pt = 0
-        seq_lens = batch.seq_lens.tolist()
-
         i = 0
-
         for req in batch.reqs:
             if req.finished():
                 continue
             # assert seq_len - pre_len == req.extend_input_len
-            input_len =
-            seq_len =
+            input_len = batch.extend_lens[i]
+            seq_len = seq_lens_cpu[i]
             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                 seq_len - input_len : seq_len
             ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
+        assert pt == batch.out_cache_loc.shape[0]

         self.positions = torch.empty_like(self.verified_id)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -337,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )

-        batch.seq_lens_sum = sum(
+        batch.seq_lens_sum = sum(seq_lens_cpu)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id

@@ -565,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
         finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
+        has_finished = False
+
         # iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -578,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
                     finished_extend_len[req.rid] = j + 1
                     req.check_finished()
                     if req.finished():
-
+                        has_finished = True
                         # set all tokens after finished token to -1 and break
                         accept_index[i, j + 1 :] = -1
                         break
@@ -587,12 +603,12 @@ class EagleVerifyInput(SpecInfo):
             if not req.finished():
                 new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
+            req.spec_verify_ct += 1
         accept_length = (accept_index != -1).sum(dim=1) - 1

         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-        verified_id_cpu = verified_id.tolist()

         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
@@ -614,7 +630,13 @@ class EagleVerifyInput(SpecInfo):
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.
+            draft_input.accept_length_cpu = [
+                accept_length_cpu[i] for i in unfinished_index
+            ]
+            if has_finished:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+            else:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens

         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         return (
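The bookkeeping in the verify path above drives the new draft-extend flow: rejected draft positions are marked with -1 in `accept_index`, `accept_length` counts the extra tokens accepted per request, and `accept_length_cpu` plus one becomes `extend_lens` in `prepare_extend_after_decode`. A toy illustration of that arithmetic, detached from the KV-cache handling (tensor values are made up):

```python
import torch

# Two requests, up to 3 draft positions each; -1 marks a rejected draft token.
accept_index = torch.tensor([[5, 6, -1], [7, -1, -1]])

accept_length = (accept_index != -1).sum(dim=1) - 1  # tensor([1, 0])
accept_length_cpu = accept_length.tolist()

# In prepare_extend_after_decode, each request extends by accepted tokens + 1.
extend_lens = [x + 1 for x in accept_length_cpu]  # [2, 1]

# Flattened indices of the tokens that are kept after verification.
kept = accept_index[accept_index != -1]  # tensor([5, 6, 7])
print(accept_length_cpu, extend_lens, kept.tolist())
```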
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+from sglang.srt.utils import rank0_print


 class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):

     def forward_draft_decode(self, batch: ScheduleBatch):
         batch.spec_info.prepare_for_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)

     def forward_draft_extend(self, batch: ScheduleBatch):
         self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
         self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
         batch.req_to_token_pool = runner.req_to_token_pool

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        seq_lens_backup = batch.seq_lens
+
         self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        if batch.spec_info.has_finished:
-            index = batch.spec_info.unfinished_index
-            seq_lens = batch.seq_lens
-            batch.seq_lens = batch.seq_lens[index]
-
         batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
-
-        batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
-        batch.forward_mode = ForwardMode.DECODE
-        if batch.spec_info.has_finished:
-            batch.seq_lens = seq_lens
         self._set_mem_pool(batch, self.target_worker.model_runner)

+        # Restore backup.
+        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.forward_mode = ForwardMode.DECODE
+        batch.seq_lens = seq_lens_backup
+
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ):