sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. The information is provided for informational purposes only.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api_adapter.py
CHANGED
@@ -1,9 +1,12 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
+
+import asyncio
 import json
 import os
+from http import HTTPStatus
 
-from fastapi import
-from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import StreamingResponse, JSONResponse
 
 from sglang.srt.conversation import (
     Conversation,
@@ -26,14 +29,36 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model
-… (1 removed line not shown)
 
 chat_template_name = None
 
+
+def create_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
+    error = ErrorResponse(message=message,
+                          type=err_type,
+                          code=status_code.value)
+    return JSONResponse(content=error.model_dump(),
+                        status_code=error.code)
+
+
+def create_streaming_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+    error = ErrorResponse(message=message,
+                          type=err_type,
+                          code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name
 
@@ -73,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)
 
-… (2 removed lines not shown)
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
 
     adapted_request = GenerateReqInput(
         text=request.prompt,
@@ -92,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()
 
     if adapted_request.stream:
 
         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-… (12 removed lines not shown)
-            if
-… (6 removed lines not shown)
+            try:
+                async for content in tokenizer_manager.generate_request(
+                        adapted_request, raw_request):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer:  # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
+                                n_prev_token:
+                            ],
+                        )
+
+                        n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                     else:
-… (1 removed line not shown)
-                        prefill_top_logprobs = None
-… (1 removed line not shown)
-                        logprobs = to_openai_style_logprobs(
-                            prefill_token_logprobs=prefill_token_logprobs,
-                            prefill_top_logprobs=prefill_top_logprobs,
-                            decode_token_logprobs=content["meta_info"][
-                                "decode_token_logprobs"
-                            ][n_prev_token:],
-                            decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                                n_prev_token:
-                            ],
-                        )
+                        logprobs = None
 
-… (23 removed lines not shown)
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = content["text"]
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream"
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(adapted_request))
 
     # Non-streaming response.
-… (2 removed lines not shown)
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
 
+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
@@ -192,7 +226,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         index=0,
         text=text,
         logprobs=logprobs,
-        finish_reason=
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
@@ -211,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)
 
-… (2 removed lines not shown)
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
 
     # Prep the data needed for the underlying GenerateReqInput:
     # - prompt: The full prompt string.
@@ -257,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()
 
     if adapted_request.stream:
 
@@ -265,46 +298,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True
 
             stream_buffer = ""
-… (4 removed lines not shown)
+            try:
+                async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
-                        delta=DeltaMessage(
-                        finish_reason=
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
                     )
-                yield f"data: {
-… (3 removed lines not shown)
-                stream_buffer = text
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
            yield "data: [DONE]\n\n"
 
-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream"
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(adapted_request))
 
     # Non-streaming response.
-… (1 removed line not shown)
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = ChatCompletionResponse(
         id=ret["meta_info"]["id"],
@@ -332,7 +377,7 @@ def to_openai_style_logprobs(
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(logprob)
 
-        # Not
+        # Not supported yet
         ret_logprobs.text_offset.append(-1)
 
     def append_top_logprobs(top_logprobs):
@@ -353,4 +398,4 @@ def to_openai_style_logprobs(
     if decode_top_logprobs is not None:
         append_top_logprobs(decode_top_logprobs)
 
-    return ret_logprobs
+    return ret_logprobs
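Note: the snippet below is a minimal, standalone sketch of the error path these changes introduce. Invalid requests and mid-stream failures are now reported as OpenAI-style error payloads instead of surfacing as raw exceptions. `ErrorResponse` mirrors the model added to `sglang/srt/openai_protocol.py` below; the FastAPI `JSONResponse` wrapper used by the real `create_error_response` is omitted so the example runs on its own, and pydantic v2 (`model_dump()`) is assumed.

```python
# Standalone sketch of the streaming error format added in this diff.
import json
from http import HTTPStatus
from typing import Optional

from pydantic import BaseModel


class ErrorResponse(BaseModel):
    # Mirrors the model added to sglang/srt/openai_protocol.py in this diff.
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


def create_streaming_error_response(
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
    # Same shape as the helper in openai_api_adapter.py: the error is wrapped
    # in a JSON object and sent to the client as an ordinary SSE data line.
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    return json.dumps({"error": error.model_dump()})


if __name__ == "__main__":
    payload = create_streaming_error_response("n != 1 is not supported")
    print(f"data: {payload}\n\n", end="")
    print("data: [DONE]\n\n", end="")
    # data: {"error": {"object": "error", "message": "n != 1 is not supported",
    #                  "type": "BadRequestError", "param": null, "code": 400}}
```

A streaming client therefore sees the failure as a regular `data:` line followed by `data: [DONE]`, rather than a dropped connection.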
sglang/srt/openai_protocol.py
CHANGED
@@ -1,4 +1,5 @@
 """pydantic models for OpenAI API protocol"""
+
 import time
 from typing import Dict, List, Optional, Union
 
@@ -6,6 +7,14 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -178,4 +187,4 @@ class ChatCompletionStreamResponse(BaseModel):
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseStreamChoice]
+    choices: List[ChatCompletionResponseStreamChoice]
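The adapter diff above also drops the `jsonify_pydantic_model` helper in favor of pydantic v2's built-in `model_dump_json()` when emitting stream chunks. The sketch below shows that serialization style; `StreamChunk` is a stand-in for illustration, not the real `ChatCompletionStreamResponse`.

```python
# Sketch of the chunk serialization style used by the updated adapter
# (pydantic v2's model_dump_json instead of jsonify_pydantic_model).
import time
from typing import List

from pydantic import BaseModel, Field


class StreamChunk(BaseModel):
    # Stand-in model with the same general shape as a stream-response chunk.
    id: str
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[dict] = Field(default_factory=list)


chunk = StreamChunk(id="cmpl-123", model="test-model",
                    choices=[{"index": 0, "delta": {"content": "Hello"}}])
# One server-sent-events line, as yielded by the streaming endpoints.
print(f"data: {chunk.model_dump_json()}\n\n", end="")
```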
sglang/srt/server.py
CHANGED
@@ -9,7 +9,8 @@ import os
 import sys
 import threading
 import time
-from
+from http import HTTPStatus
+from typing import Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -27,19 +28,23 @@ from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
-… (2 removed lines not shown)
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    get_exception_traceback,
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware
 )
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -69,7 +74,7 @@ async def get_server_args():
 
 @app.get("/flush_cache")
 async def flush_cache():
-… (1 removed line not shown)
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -77,24 +82,32 @@ async def flush_cache():
     )
 
 
-… (1 removed line not shown)
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-… (1 removed line not shown)
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
 
         async def stream_results():
-… (1 removed line not shown)
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(stream_results(), media_type="text/event-stream"
+        return StreamingResponse(stream_results(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(obj))
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
+
 
-… (2 removed lines not shown)
-        return ret
-    except ValueError as e:
-        return JSONResponse({"error": str(e)}, status_code=400)
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)
 
 
 @app.post("/v1/completions")
@@ -129,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
     )
+
+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=
-        router_port=
-        detokenizer_port=
-… (1 removed line not shown)
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
     )
 
     # Launch processes
@@ -144,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=
+        target=start_process,
         args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
@@ -179,6 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
+    # Send a warmup request
    def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -190,27 +222,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             time.sleep(0.5)
             try:
                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                success = True  # Set flag to True if request succeeds
                 break
             except requests.exceptions.RequestException as e:
                 pass
 
         # Send a warmup request
         try:
-… (5 removed lines not shown)
-            "
-… (1 removed line not shown)
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 16,
+                        },
                     },
-… (5 removed lines not shown)
-        except Exception as e:
+                    headers=headers,
+                    timeout=600,
+                )
+                assert res.status_code == 200
+        except Exception:
             if pipe_finish_writer is not None:
                 pipe_finish_writer.send(get_exception_traceback())
             print(f"Initialization failed. warmup error: {e}")
@@ -221,6 +253,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
@@ -237,19 +271,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 class Runtime:
     def __init__(
         self,
-… (1 removed line not shown)
+        log_level: str = "error",
         model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
 
         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
             self.server_args.tp_size,
+            self.server_args.dp_size,
         )
 
         self.url = self.server_args.url()
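The port bookkeeping added to `launch_server` above can be read as follows: after the three fixed ports (tokenizer, router, detokenizer), each data-parallel worker takes one NCCL port plus `tp_size` model ports from `additional_ports`. The sketch below reproduces only that slicing arithmetic; `ModelPortArgs` is a stand-in for the dataclass in `sglang.srt.server_args`, and the port numbers are made up rather than coming from `allocate_init_ports`.

```python
# Sketch of the per-worker port layout built by launch_server for data parallelism.
from dataclasses import dataclass
from typing import List


@dataclass
class ModelPortArgs:  # stand-in for sglang.srt.server_args.ModelPortArgs
    nccl_port: int
    model_tp_ports: List[int]


def split_ports(ports: List[int], dp_size: int, tp_size: int) -> List[ModelPortArgs]:
    # ports[0:3] are tokenizer / router / detokenizer; each dp worker then
    # takes one NCCL port followed by tp_size model ports.
    tp = tp_size
    return [
        ModelPortArgs(
            nccl_port=ports[3 + i * (tp + 1)],
            model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
        )
        for i in range(dp_size)
    ]


ports = list(range(30000, 30000 + 3 + 2 * (2 + 1)))  # dp_size=2, tp_size=2
for args in split_ports(ports, dp_size=2, tp_size=2):
    print(args)
# ModelPortArgs(nccl_port=30003, model_tp_ports=[30004, 30005])
# ModelPortArgs(nccl_port=30006, model_tp_ports=[30007, 30008])
```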