sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -20
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -1
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
- sglang/srt/managers/controller/infer_batch.py +76 -72
- sglang/srt/managers/controller/manager_multi.py +109 -98
- sglang/srt/managers/controller/manager_single.py +105 -50
- sglang/srt/managers/controller/model_runner.py +42 -18
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +143 -156
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +46 -58
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +2 -8
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +130 -108
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +114 -90
- sglang/srt/server_args.py +27 -17
- sglang/srt/utils.py +17 -118
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.20.dist-info/RECORD +0 -82
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py
ADDED
@@ -0,0 +1,208 @@
+"""Pydantic models for OpenAI API protocol"""
+
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+    ignore_eos: Optional[bool] = False
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
+    usage: UsageInfo
+
+
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant"]
+    content: str
+
+
+class ChatCompletionMessageContentTextPart(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
+class ChatCompletionMessageContentImageURL(BaseModel):
+    url: str
+    detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+class ChatCompletionMessageContentImagePart(BaseModel):
+    type: Literal["image_url"]
+    image_url: ChatCompletionMessageContentImageURL
+
+
+ChatCompletionMessageContentPart = Union[
+    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+]
+
+
+class ChatCompletionMessageUserParam(BaseModel):
+    role: Literal["user"]
+    content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ChatCompletionMessageParam = Union[
+    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+]
+
+
+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
+class ChatCompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
+    model: str
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+
+
+class ChatMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
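The models above make up the new `sglang/srt/openai_api/protocol.py` (+208 -0 in the file summary). As a rough illustration of how a client payload maps onto these schemas, here is a minimal sketch; the import path follows the new module layout shown in the diff, and `parse_obj` is the pydantic v1 spelling (`model_validate` on v2), so treat both as assumptions about the installed environment:

```python
# Minimal sketch: validate an OpenAI-style chat payload with the new models.
from sglang.srt.openai_api.protocol import ChatCompletionRequest, UsageInfo

payload = {
    "model": "my-model",  # illustrative model name
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    "temperature": 0.7,
    "max_tokens": 32,
}

req = ChatCompletionRequest.parse_obj(payload)  # model_validate(...) on pydantic v2
print(req.model, req.temperature, req.max_tokens)

# UsageInfo mirrors the OpenAI "usage" block attached to responses.
print(UsageInfo(prompt_tokens=12, completion_tokens=5, total_tokens=17))
```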
sglang/srt/openai_protocol.py
CHANGED
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
 class ErrorResponse(BaseModel):
     object: str = "error"
     message: str
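The same `ModelCard`/`ModelList` definitions are back-ported into the legacy `sglang/srt/openai_protocol.py`. A small sketch of the shape they serialize to (the model id is illustrative, and `.dict()` is the pydantic v1 call; pydantic v2 would use `.model_dump()`):

```python
# Sketch: this is the structure the /v1/models endpoint returns.
from sglang.srt.openai_protocol import ModelCard, ModelList

card = ModelCard(id="example-model", root="example-model")  # illustrative id
print(ModelList(data=[card]).dict())
# {'object': 'list', 'data': [{'id': 'example-model', 'object': 'model',
#  'created': <epoch seconds>, 'owned_by': 'sglang', 'root': 'example-model'}]}
```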
sglang/srt/sampling_params.py
CHANGED
@@ -20,6 +20,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
+        n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
+        self.n = n
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
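The only change here is a new `n` keyword on `SamplingParams`, stored on the instance alongside the other sampling options, presumably to back the OpenAI `n` (number of choices) field defined in the protocol module above. A minimal sketch, assuming the remaining constructor arguments keep the defaults implied by the fragment:

```python
from sglang.srt.sampling_params import SamplingParams

# "n" is the new keyword from this hunk; everything else uses defaults.
params = SamplingParams(n=2)
print(params.n)      # 2
print(params.regex)  # None, pre-existing option from the same signature
```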
sglang/srt/server.py
CHANGED
@@ -26,34 +26,33 @@ import uvloop
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
+from sglang.srt.managers.controller.manager_single import launch_tp_servers
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
-from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.
+from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.
+from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-
-    send_addrs_to_rank_0,
-    start_rpyc_service_process,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
 
@@ -65,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+# Put some args for easily access
+global_server_args_dict = {}
+
 
 @app.get("/health")
 async def health() -> Response:
@@ -96,6 +98,7 @@ async def flush_cache():
 
 
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:
 
         async def stream_results():
@@ -136,7 +139,30 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    model_names = [tokenizer_manager.model_path]
+    model_cards = []
+    for model_name in model_names:
+        model_cards.append(ModelCard(id=model_name, root=model_name))
+    return ModelList(data=model_cards)
+
+
+def _set_global_server_args(server_args: ServerArgs):
+    global global_server_args_dict
+    global_server_args_dict = {
+        "disable_flashinfer": server_args.disable_flashinfer,
+        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+    }
+
+
+def launch_server(
+    server_args: ServerArgs,
+    model_overide_args: Optional[dict] = None,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """Launch an HTTP server."""
     global tokenizer_manager
 
     logging.basicConfig(
@@ -147,6 +173,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
@@ -154,7 +182,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -162,68 +190,61 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
+    _set_global_server_args(server_args)
 
     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
-
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-
+        nccl_ports=ports[3:],
     )
 
-    #
-    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
+        assert server_args.dp_size == 1, "Multi-node dp is not supported."
+
         if server_args.node_rank != 0:
-
-
-
-
-
-
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [
+                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+            ]
+            tp_rank_range = list(
+                range(
+                    server_args.node_rank * tp_size_local,
+                    (server_args.node_rank + 1) * tp_size_local,
+                )
             )
-
-
-
+            procs = launch_tp_servers(
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                ports[3],
+                model_overide_args,
             )
             while True:
                 pass
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args,
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -235,68 +256,30 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()
 
     # Wait for the model to finish loading
-
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
 
-    if
-
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed.
+            f"Initialization failed. controller_init_state: {controller_init_state}",
+            flush=True,
         )
         print(
            f"Initialization failed. detoken_init_state: {detoken_init_state}",
            flush=True,
        )
        sys.exit(1)
-    assert
+    assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
     # Send a warmup request
-
-
-
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-                assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+    )
     t.start()
 
     # Listen for requests
@@ -313,6 +296,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         t.join()
 
 
+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.
@@ -334,7 +359,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
 
@@ -347,7 +371,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args,
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
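Beyond the plumbing changes (explicit `PortArgs`, multi-node tensor parallelism via `launch_tp_servers`, warmup moved into a standalone `_wait_and_warmup`), the server gains a `/v1/models` endpoint backed by the `ModelCard`/`ModelList` models. A hedged sketch of exercising the new endpoint and the same `/generate` payload the warmup uses, against an already running server; the host and port are assumptions about a local deployment:

```python
import requests

base_url = "http://127.0.0.1:30000"  # assumed host/port of a locally launched server

# New in 0.1.22: list the served model (see available_models above).
print(requests.get(base_url + "/v1/models", timeout=5).json())

# Same request body that _wait_and_warmup posts to /generate during startup.
resp = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
    timeout=600,
)
print(resp.json())
```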