sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- sglang/bench_offline_throughput.py +1 -0
- sglang/bench_serving.py +11 -3
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
- sglang/srt/layers/moe/topk.py +14 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +91 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +71 -34
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +95 -55
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -6
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/llama.py +13 -2
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +79 -2
- sglang/srt/openai_api/protocol.py +50 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +45 -39
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_eagle.py
ADDED
@@ -0,0 +1,132 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config, prefix)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class LlamaForCausalLMEagle(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [LlamaForCausalLMEagle]
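The heart of this new EAGLE draft model is the `fc` projection, which fuses the current token embedding with the hidden state carried over from the target model (`forward_batch.spec_info.hidden_states`) before the decoder layers run. A minimal standalone sketch of that fusion step, with hypothetical dimensions and random tensors standing in for the real `ForwardBatch` plumbing:

```python
import torch
from torch import nn

hidden_size = 2048          # hypothetical; comes from LlamaConfig in the real model
num_draft_tokens = 4        # hypothetical batch of draft positions

embed = torch.randn(num_draft_tokens, hidden_size)          # embed_tokens(input_ids)
target_hidden = torch.randn(num_draft_tokens, hidden_size)  # spec_info.hidden_states

# Mirrors self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) above:
# concatenate [embedding, target hidden state] and project back to hidden_size.
fc = nn.Linear(hidden_size * 2, hidden_size)
fused = fc(torch.cat((embed, target_hidden), dim=-1))
print(fused.shape)  # torch.Size([4, 2048])
```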
sglang/srt/openai_api/adapter.py
CHANGED
@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
     FileDeleteResponse,
     FileRequest,
     FileResponse,
+    FunctionResponse,
     LogProbs,
+    ToolCall,
     TopLogprob,
     UsageInfo,
 )
+from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -517,6 +520,7 @@ def v1_generate_request(
         "repetition_penalty": request.repetition_penalty,
         "regex": request.regex,
         "json_schema": request.json_schema,
+        "ebnf": request.ebnf,
         "n": request.n,
         "no_stop_trim": request.no_stop_trim,
         "ignore_eos": request.ignore_eos,
@@ -692,6 +696,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):

 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    if "extra_body" in request_json:
+        extra = request_json["extra_body"]
+        if "ebnf" in extra:
+            request_json["ebnf"] = extra["ebnf"]
+        if "regex" in extra:
+            request_json["regex"] = extra["regex"]
+        # remove extra_body to avoid pydantic conflict
+        del request_json["extra_body"]
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)

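The `extra_body` handling above lets an OpenAI-compatible client forward the SRT-only `ebnf` and `regex` fields without tripping pydantic validation on `CompletionRequest`. A hedged usage sketch, assuming a local server on port 30000 and the official `openai` Python client; the model name and grammar are illustrative:

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.completions.create(
    model="default",
    prompt="Answer with yes or no: is the sky blue?",
    max_tokens=8,
    # v1_completions pops "ebnf" out of extra_body and passes it on as a sampling param.
    extra_body={"ebnf": 'root ::= "yes" | "no"'},
)
print(response.choices[0].text)
```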
@@ -870,6 +882,21 @@ def v1_chat_generate_request(
         # None skips any image processing in GenerateReqInput.
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
+            tools = None
+            if request.tools and request.tool_choice != "none":
+                request.skip_special_tokens = False
+                if request.stream:
+                    logger.warning("Streaming is not supported with tools.")
+                    request.stream = False
+                if not isinstance(request.tool_choice, str):
+                    tools = [
+                        item.function.model_dump()
+                        for item in request.tools
+                        if item.function.name == request.tool_choice.function.name
+                    ]
+                else:
+                    tools = [item.function.model_dump() for item in request.tools]
+
             if chat_template_name is None:
                 openai_compatible_messages = []
                 for message in request.messages:
@@ -893,6 +920,7 @@ def v1_chat_generate_request(
                     openai_compatible_messages,
                     tokenize=True,
                     add_generation_prompt=True,
+                    tools=tools,
                 )
                 if assistant_prefix:
                     prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -936,6 +964,7 @@ def v1_chat_generate_request(
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
+            "ebnf": request.ebnf,
             "n": request.n,
             "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
@@ -1031,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

         finish_reason = ret_item["meta_info"]["finish_reason"]

+        tool_calls = None
+        text = ret_item["text"]
+
+        if isinstance(request, list):
+            tool_choice = request[idx].tool_choice
+            tools = request[idx].tools
+        else:
+            tool_choice = request.tool_choice
+            tools = request.tools
+
+        if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
+            if finish_reason == "stop":
+                finish_reason = "tool_calls"
+            try:
+                text, call_info_list = parse_tool_response(text, tools)  # noqa
+                tool_calls = [
+                    ToolCall(
+                        id=str(call_info[0]),
+                        function=FunctionResponse(
+                            name=call_info[1], arguments=call_info[2]
+                        ),
+                    )
+                    for call_info in call_info_list
+                ]
+            except Exception as e:
+                logger.error(f"Exception: {e}")
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "Failed to parse fc related info to json format!",
+                )
+
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
-                "message": {
+                "message": {
+                    "role": "assistant",
+                    "content": ret_item["text"] if tool_calls is None else None,
+                    "tool_calls": tool_calls,
+                },
                 "logprobs": choice_logprobs,
                 "finish_reason": (finish_reason["type"] if finish_reason else ""),
                 "matched_stop": (
@@ -1047,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
-                message=ChatMessage(
+                message=ChatMessage(
+                    role="assistant",
+                    content=ret_item["text"] if tool_calls is None else None,
+                    tool_calls=tool_calls,
+                ),
                 logprobs=choice_logprobs,
                 finish_reason=(finish_reason["type"] if finish_reason else ""),
                 matched_stop=(
@@ -1108,6 +1176,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    if "extra_body" in request_json:
+        extra = request_json["extra_body"]
+        # For example, if 'ebnf' is given:
+        if "ebnf" in extra:
+            request_json["ebnf"] = extra["ebnf"]
+        if "regex" in extra:
+            request_json["regex"] = extra["regex"]
+        # remove extra_body to avoid pydantic conflict
+        del request_json["extra_body"]
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

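Taken together, these changes should let a standard OpenAI-style function-calling request round-trip through /v1/chat/completions: tool schemas are injected into the chat template, and the generated text is parsed back into `tool_calls` via `parse_tool_response`. A hedged sketch with an illustrative tool schema and model name; note that the server forces `stream=False` when tools are present:

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)

message = response.choices[0].message
if message.tool_calls:
    call = message.tool_calls[0]
    print(call.function.name, call.function.arguments)
else:
    print(message.content)
```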
sglang/srt/openai_api/protocol.py
CHANGED
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    ebnf: Optional[str] = None


 class CompletionResponseChoice(BaseModel):
@@ -256,6 +257,34 @@ class ResponseFormat(BaseModel):
     json_schema: Optional[JsonSchemaResponseFormat] = None


+class Function(BaseModel):
+    """Function descriptions."""
+
+    description: Optional[str] = Field(default=None, examples=[None])
+    name: str
+    parameters: Optional[object] = None
+
+
+class Tool(BaseModel):
+    """Function wrapper."""
+
+    type: str = Field(default="function", examples=["function"])
+    function: Function
+
+
+class ToolChoiceFuncName(BaseModel):
+    """The name of tool choice function."""
+
+    name: str
+
+
+class ToolChoice(BaseModel):
+    """The tool choice definition."""
+
+    function: ToolChoiceFuncName
+    type: Literal["function"] = Field(default="function", examples=["function"])
+
+
 class ChatCompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
@@ -276,6 +305,10 @@ class ChatCompletionRequest(BaseModel):
     temperature: float = 0.7
     top_p: float = 1.0
     user: Optional[str] = None
+    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
+    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
+        default="auto", examples=["none"]
+    )  # noqa

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -288,11 +321,28 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    ebnf: Optional[str] = None
+
+
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: str
+    type: Literal["function"] = "function"
+    function: FunctionResponse


 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionResponseChoice(BaseModel):
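These pydantic models can also be exercised directly, which is a quick way to see how tool definitions validate before a request ever reaches the server. A sketch only: the field values are illustrative, and it assumes plain dicts coerce into the nested models the same way they do for incoming JSON:

```python
from sglang.srt.openai_api.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="default",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {"name": "get_weather", "parameters": {"type": "object"}},
        }
    ],
    # A dict here should coerce into the ToolChoice model defined above.
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
print(request.tool_choice.function.name)  # get_weather
```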
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -36,6 +36,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        ebnf: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim

         # Process some special cases
@@ -111,8 +113,13 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
-
-
+        grammars = [
+            self.json_schema,
+            self.regex,
+            self.ebnf,
+        ]  # since mutually exclusive, only one can be set
+        if sum(x is not None for x in grammars) > 1:
+            raise ValueError("Only one of regex, json_schema, or ebnf can be set.")

     def normalize(self, tokenizer):
         # Process stop strings
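The new guard treats `json_schema`, `regex`, and `ebnf` as mutually exclusive constrained-decoding grammars. The same logic in isolation, as a standalone sketch rather than the class itself:

```python
from typing import Optional


def check_grammar_exclusivity(
    json_schema: Optional[str] = None,
    regex: Optional[str] = None,
    ebnf: Optional[str] = None,
) -> None:
    # Mirrors the guard added above: at most one grammar may be set per request.
    grammars = [json_schema, regex, ebnf]
    if sum(x is not None for x in grammars) > 1:
        raise ValueError("Only one of regex, json_schema, or ebnf can be set.")


check_grammar_exclusivity(ebnf='root ::= "yes" | "no"')  # fine
# check_grammar_exclusivity(regex=r"yes|no", ebnf='root ::= "yes"')  # would raise
```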
sglang/srt/server.py
CHANGED
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -109,6 +110,7 @@ app.add_middleware(
 tokenizer_manager: TokenizerManager = None
 scheduler_info: Dict = None

+
 ##### Native API endpoints #####


@@ -245,16 +247,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     try:
         ret = await tokenizer_manager.get_weights_by_name(obj, request)
         if ret is None:
-            return
-                {"error": {"message": "Get parameter by name failed"}},
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
+            return _create_error_response("Get parameter by name failed")
         else:
             return ORJSONResponse(ret, status_code=200)
     except Exception as e:
-        return
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/open_session", methods=["GET", "POST"])
@@ -262,11 +259,13 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
     try:
         session_id = await tokenizer_manager.open_session(obj, request)
+        if session_id is None:
+            raise Exception(
+                "Failed to open the session. Check if a session with the same id is still open."
+            )
         return session_id
     except Exception as e:
-        return
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +275,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         await tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
-        return
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 # fastapi implicitly converts json in the request to obj (dataclass)
@@ -312,9 +309,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
-        return
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/encode", methods=["POST", "PUT"])
@@ -325,9 +320,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/classify", methods=["POST", "PUT"])
@@ -338,9 +331,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 ##### OpenAI-compatible API endpoints #####
@@ -416,6 +407,12 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
 def launch_engine(
     server_args: ServerArgs,
 ):
@@ -493,7 +490,16 @@ def launch_engine(
     # Wait for model to finish loading
     scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
-
+        try:
+            data = scheduler_pipe_readers[i].recv()
+        except EOFError as e:
+            logger.exception(e)
+            logger.error(
+                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
+            )
+            scheduler_procs[i].join()
+            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
+            raise

         if data["status"] != "ready":
             raise RuntimeError(
@@ -501,7 +507,7 @@ def launch_engine(
             )
         scheduler_infos.append(data)

-    # Assume all schedulers have same
+    # Assume all schedulers have same scheduler_info
     scheduler_info = scheduler_infos[0]


@@ -849,12 +855,10 @@ class Engine:
             group_name=group_name,
             backend=backend,
         )
-
-        async def _init_group():
-            return await tokenizer_manager.init_weights_update_group(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
+        return loop.run_until_complete(
+            tokenizer_manager.init_weights_update_group(obj, None)
+        )

     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
@@ -863,22 +867,24 @@ class Engine:
             dtype=dtype,
             shape=shape,
         )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_distributed(obj, None)
+        )

-
-
-
+    def update_weights_from_tensor(self, name, tensor):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_tensor(obj, None)
+        )

     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
-
-        async def _get_weights():
-            return await tokenizer_manager.get_weights_by_name(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
+        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))


 class Runtime:
@@ -888,7 +894,7 @@ class Runtime:
     using the commond line interface.

     It is mainly used for the frontend language.
-    You should use the Engine class if you want to do normal offline processing.
+    You should use the Engine class above if you want to do normal offline processing.
     """

     def __init__(
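On the Engine side, the new `update_weights_from_tensor` follows the same synchronous wrapper pattern as the other weight-update helpers: build the request object, then drive the tokenizer manager's coroutine on the current event loop. A hedged usage sketch; the model path and parameter name are illustrative, and it assumes a GPU environment where the offline Engine can start:

```python
import torch
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Read back a truncated copy of a parameter by name...
weights = engine.get_weights_by_name("model.embed_tokens.weight", truncate_size=5)

# ...and push an updated tensor into the running model. The tensor must match the
# real parameter's shape and dtype; the zeros tensor here is only a placeholder.
engine.update_weights_from_tensor("model.embed_tokens.weight", torch.zeros(1))
```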
sglang/srt/server_args.py
CHANGED
@@ -55,7 +55,7 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None

-    # Port
+    # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000

@@ -68,6 +68,7 @@ class ServerArgs:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    prefill_only_one_req: bool = False

     # Other runtime options
     tp_size: int = 1
@@ -94,6 +95,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+
     # Expert parallelism
     ep_size: int = 1

@@ -217,6 +219,13 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True

+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
@@ -229,12 +238,6 @@ class ServerArgs:
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
             )
-        # Expert parallelism
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.info(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )

         # GGUF
         if (
@@ -430,13 +433,18 @@ class ServerArgs:
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
-
         parser.add_argument(
             "--cpu-offload-gb",
             type=int,
             default=ServerArgs.cpu_offload_gb,
             help="How many GBs of RAM to reserve for CPU offloading",
         )
+        parser.add_argument(
+            "--prefill-only-one-req",
+            type=bool,
+            help="If true, we only prefill one request at one prefill batch",
+            default=ServerArgs.prefill_only_one_req,
+        )

         # Other runtime options
         parser.add_argument(
@@ -555,6 +563,7 @@ class ServerArgs:
                 "shortest_queue",
             ],
         )
+
         # Expert parallelism
         parser.add_argument(
             "--expert-parallel-size",
@@ -777,28 +786,6 @@ class ServerArgs:
             help="Delete the model checkpoint after loading the model.",
         )

-        # Deprecated arguments
-        parser.add_argument(
-            "--enable-overlap-schedule",
-            action=DeprecatedAction,
-            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer",
-            action=DeprecatedAction,
-            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer-sampling",
-            action=DeprecatedAction,
-            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
-        )
-        parser.add_argument(
-            "--disable-disk-cache",
-            action=DeprecatedAction,
-            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
-        )
-
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size