sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +976 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -2
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +39 -24
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
- sglang/srt/managers/controller/infer_batch.py +90 -63
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +41 -26
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +136 -149
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +32 -11
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +81 -23
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +132 -84
- sglang/srt/server_args.py +35 -21
- sglang/srt/utils.py +65 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
- sglang-0.1.24.dist-info/RECORD +105 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py
ADDED

```diff
@@ -0,0 +1,208 @@
+"""Pydantic models for OpenAI API protocol"""
+
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+    ignore_eos: Optional[bool] = False
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
+    usage: UsageInfo
+
+
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant"]
+    content: str
+
+
+class ChatCompletionMessageContentTextPart(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
+class ChatCompletionMessageContentImageURL(BaseModel):
+    url: str
+    detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+class ChatCompletionMessageContentImagePart(BaseModel):
+    type: Literal["image_url"]
+    image_url: ChatCompletionMessageContentImageURL
+
+
+ChatCompletionMessageContentPart = Union[
+    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+]
+
+
+class ChatCompletionMessageUserParam(BaseModel):
+    role: Literal["user"]
+    content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ChatCompletionMessageParam = Union[
+    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+]
+
+
+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
+class ChatCompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
+    model: str
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+
+
+class ChatMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
```
sglang/srt/openai_protocol.py
CHANGED
```diff
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
 class ErrorResponse(BaseModel):
     object: str = "error"
     message: str
```
sglang/srt/sampling_params.py
CHANGED
```diff
@@ -20,6 +20,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
+        n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
+        self.n = n
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
```
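The new `n` field presumably mirrors the OpenAI `n` parameter (number of completions per prompt) that `CompletionRequest` and `ChatCompletionRequest` already carry. A minimal sketch of constructing it directly; the keyword names used below all appear in the hunk above, other constructor arguments are omitted.

```python
from sglang.srt.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    regex=None,
    n=2,  # ask for two sampled completions per prompt
)
print(params.n)  # -> 2
```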
sglang/srt/server.py
CHANGED
```diff
@@ -26,33 +26,33 @@ import uvloop
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
+from sglang.srt.managers.controller.manager_single import launch_tp_servers
 from sglang.srt.managers.controller.manager_single import (
-    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.
+from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.
+from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-
-    send_addrs_to_rank_0,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
 
```
```diff
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+# Put some args for easily access
+global_server_args_dict = {}
+
 
 @app.get("/health")
 async def health() -> Response:
```
```diff
@@ -95,6 +98,7 @@ async def flush_cache():
 
 
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:
 
         async def stream_results():
```
```diff
@@ -135,7 +139,43 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    model_names = [tokenizer_manager.model_path]
+    model_cards = []
+    for model_name in model_names:
+        model_cards.append(ModelCard(id=model_name, root=model_name))
+    return ModelList(data=model_cards)
+
+
+def _set_global_server_args(server_args: ServerArgs):
+    global global_server_args_dict
+    global_server_args_dict = {
+        "disable_flashinfer": server_args.disable_flashinfer,
+        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+    }
+
+
+def _set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
+def launch_server(
+    server_args: ServerArgs,
+    model_overide_args: Optional[dict] = None,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """Launch an HTTP server."""
     global tokenizer_manager
 
     logging.basicConfig(
```
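The new `/v1/models` route reports the served model as an OpenAI-style model list. A hedged usage sketch; the host and port below are assumed values for a locally running server, adjust as needed.

```python
import requests

# Assumes a server launched locally on the default port.
resp = requests.get("http://127.0.0.1:30000/v1/models")
print(resp.json())
# Expected shape, following ModelList/ModelCard above:
# {"object": "list",
#  "data": [{"id": "<model path>", "object": "model", "created": ...,
#            "owned_by": "sglang", "root": "<model path>"}]}
```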
```diff
@@ -146,6 +186,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
```
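`set_ulimit` comes from `sglang.srt.utils` and its body is not shown in this diff. As an assumption only (not the actual sglang implementation), a helper like this typically raises the soft open-file limit so a busy server does not run out of file descriptors:

```python
import resource


def set_ulimit(target_soft_limit: int = 65535) -> None:
    """Illustrative sketch: raise RLIMIT_NOFILE's soft limit toward the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft < target_soft_limit:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (min(target_soft_limit, hard), hard))
        except ValueError:
            pass  # keep the existing limit if the OS refuses the request
```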
```diff
@@ -153,7 +196,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.
+            "0.1.1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
```
```diff
@@ -162,63 +205,65 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # TODO: replace this with huggingface transformers template
     load_chat_template_for_openai_api(server_args.chat_template)
 
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
+    _set_global_server_args(server_args)
+
     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
-
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-
+        nccl_ports=ports[3:],
     )
+    logger.info(f"{server_args=}")
 
-    # Handle multi-node
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
 
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-
-
-
-
+            gpu_ids = [
+                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+            ]
+            tp_rank_range = list(
+                range(
+                    server_args.node_rank * tp_size_local,
+                    (server_args.node_rank + 1) * tp_size_local,
+                )
+            )
+            procs = launch_tp_servers(
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                ports[3],
+                model_overide_args,
+            )
             while True:
                 pass
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args,
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
```
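To make the multi-node branch concrete, here is a small worked example of the values computed above, with purely illustrative numbers: `tp_size=8` across `nnodes=2` gives `tp_size_local=4`, so node 1 serves tensor-parallel ranks 4-7 on its local GPUs 0-3.

```python
# Illustrative numbers only; mirrors the arithmetic in the hunk above.
tp_size, nnodes, node_rank = 8, 2, 1

tp_size_local = tp_size // nnodes
gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
tp_rank_range = list(range(node_rank * tp_size_local, (node_rank + 1) * tp_size_local))

print(tp_size_local)  # 4
print(gpu_ids)        # [0, 1, 2, 3, 0, 1, 2, 3]
print(tp_rank_range)  # [4, 5, 6, 7]
```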
```diff
@@ -230,68 +275,30 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()
 
     # Wait for the model to finish loading
-
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
 
-    if
-
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed.
+            f"Initialization failed. controller_init_state: {controller_init_state}",
+            flush=True,
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert
+    assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
     # Send a warmup request
-
-
-
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-            assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+    )
     t.start()
 
     # Listen for requests
```
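The controller and detokenizer report readiness back to `launch_server` over one-way pipes, each sending the literal string `"init ok"` once it has finished initializing. A self-contained sketch of that handshake pattern (the worker name and payload here are illustrative, not the sglang processes):

```python
import multiprocessing as mp


def worker(pipe_writer):
    # ... load model / initialize state here ...
    pipe_writer.send("init ok")  # tell the parent we are ready


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=worker, args=(writer,))
    proc.start()
    init_state = reader.recv()  # blocks until the child reports in
    if init_state != "init ok":
        proc.kill()
        raise RuntimeError(f"Initialization failed: {init_state}")
    proc.join()
```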
```diff
@@ -308,6 +315,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t.join()
 
 
+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+        assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.
```
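The warmup hits the native `/generate` endpoint rather than the OpenAI-compatible routes. The same request can be issued by hand once the server is up; a hedged sketch where the URL is an assumed local default (normally it comes from `server_args.url()`):

```python
import requests

url = "http://127.0.0.1:30000"  # assumed local server address

res = requests.post(
    url + "/generate",
    json={
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
    timeout=600,
)
print(res.status_code, res.json())
```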
```diff
@@ -329,7 +378,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
 
```
```diff
@@ -342,7 +390,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args,
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
```
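With the new signature, `launch_server(server_args, model_overide_args=None, pipe_finish_writer=None)` can also be called directly instead of going through `Runtime`. A hedged sketch; only the `launch_server` signature is taken from the diff above, while the `ServerArgs` fields used below (`model_path`, `port`) are assumptions based on the CLI flags and are not shown in this diff.

```python
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    # model_path and port are assumed ServerArgs fields mirroring the CLI.
    server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", port=30000)
    launch_server(server_args)
```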