sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
- sglang/srt/managers/controller/infer_batch.py +47 -49
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +35 -23
- sglang/srt/managers/controller/tp_worker.py +127 -138
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +19 -6
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +113 -84
- sglang/srt/server_args.py +23 -15
- sglang/srt/utils.py +16 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py
ADDED
@@ -0,0 +1,208 @@
+"""Pydantic models for OpenAI API protocol"""
+
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+    ignore_eos: Optional[bool] = False
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
+    usage: UsageInfo
+
+
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant"]
+    content: str
+
+
+class ChatCompletionMessageContentTextPart(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
+class ChatCompletionMessageContentImageURL(BaseModel):
+    url: str
+    detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+class ChatCompletionMessageContentImagePart(BaseModel):
+    type: Literal["image_url"]
+    image_url: ChatCompletionMessageContentImageURL
+
+
+ChatCompletionMessageContentPart = Union[
+    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+]
+
+
+class ChatCompletionMessageUserParam(BaseModel):
+    role: Literal["user"]
+    content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ChatCompletionMessageParam = Union[
+    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+]
+
+
+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
+class ChatCompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
+    model: str
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+
+
+class ChatMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
sglang/srt/openai_protocol.py
CHANGED
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
 class ErrorResponse(BaseModel):
     object: str = "error"
     message: str
sglang/srt/sampling_params.py
CHANGED
@@ -20,6 +20,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
+        n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
+        self.n = n
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
sglang/srt/server.py
CHANGED
@@ -26,33 +26,33 @@ import uvloop
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
+from sglang.srt.managers.controller.manager_single import launch_tp_servers
 from sglang.srt.managers.controller.manager_single import (
-    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.
+from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.
+from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-
-    send_addrs_to_rank_0,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
 
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+# Put some args for easily access
+global_server_args_dict = {}
+
 
 @app.get("/health")
 async def health() -> Response:
@@ -95,6 +98,7 @@ async def flush_cache():
 
 
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:
 
         async def stream_results():
@@ -135,7 +139,30 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    model_names = [tokenizer_manager.model_path]
+    model_cards = []
+    for model_name in model_names:
+        model_cards.append(ModelCard(id=model_name, root=model_name))
+    return ModelList(data=model_cards)
+
+
+def _set_global_server_args(server_args: ServerArgs):
+    global global_server_args_dict
+    global_server_args_dict = {
+        "disable_flashinfer": server_args.disable_flashinfer,
+        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+    }
+
+
+def launch_server(
+    server_args: ServerArgs,
+    model_overide_args: Optional[dict] = None,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """Launch an HTTP server."""
     global tokenizer_manager
 
     logging.basicConfig(
@@ -146,6 +173,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
@@ -153,7 +182,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -161,64 +190,61 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
+    _set_global_server_args(server_args)
 
     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
-
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-
+        nccl_ports=ports[3:],
    )
 
-    # Handle multi-node
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
 
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-
-
-
-
+            gpu_ids = [
+                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+            ]
+            tp_rank_range = list(
+                range(
+                    server_args.node_rank * tp_size_local,
+                    (server_args.node_rank + 1) * tp_size_local,
+                )
+            )
+            procs = launch_tp_servers(
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                ports[3],
+                model_overide_args,
+            )
             while True:
                 pass
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args,
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -230,68 +256,30 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()
 
     # Wait for the model to finish loading
-
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
 
-    if
-
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed.
+            f"Initialization failed. controller_init_state: {controller_init_state}",
+            flush=True,
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert
+    assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
     # Send a warmup request
-
-
-
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-                assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+    )
     t.start()
 
     # Listen for requests
@@ -308,6 +296,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t.join()
 
 
+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.
@@ -329,7 +359,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
            self.server_args.additional_ports,
-            self.server_args.tp_size,
            self.server_args.dp_size,
        )
 
@@ -342,7 +371,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args,
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
sglang/srt/server_args.py
CHANGED
@@ -33,7 +33,7 @@ class ServerArgs:
 
     # Other runtime options
     tp_size: int = 1
-    stream_interval: int =
+    stream_interval: int = 1
     random_seed: Optional[int] = None
 
     # Logging
@@ -57,6 +57,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    efficient_weight_load: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -166,6 +167,15 @@ class ServerArgs:
             "--quantization",
             type=str,
             default=ServerArgs.quantization,
+            choices=[
+                "awq",
+                "fp8",
+                "gptq",
+                "marlin",
+                "gptq_marlin",
+                "squeezellm",
+                "bitsandbytes",
+            ],
             help="The quantization method.",
         )
         parser.add_argument(
@@ -243,13 +253,13 @@ class ServerArgs:
         parser.add_argument(
             "--show-time-cost",
             action="store_true",
-            help="Show time cost of custom marks",
+            help="Show time cost of custom marks.",
         )
         parser.add_argument(
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server",
+            help="Set API key of the server.",
         )
 
         # Data parallelism
@@ -285,17 +295,17 @@ class ServerArgs:
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels",
+            help="Disable flashinfer inference kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Disable RadixAttention for prefix caching.",
         )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
-            help="Disable regex jump-forward",
+            help="Disable regex jump-forward.",
         )
         parser.add_argument(
             "--disable-cuda-graph",
@@ -318,6 +328,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--efficient-weight-load",
+            action="store_true",
+            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -337,16 +352,9 @@ class ServerArgs:
         )
 
 
-@dataclasses.dataclass
-class ModelPortArgs:
-    nccl_port: int
-    model_tp_ips: List[str]
-    model_tp_ports: List[int]
-
-
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
-
+    controller_port: int
     detokenizer_port: int
-
+    nccl_ports: List[int]