sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
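The changes below center on the OpenAI-compatible adapter, its protocol models, and the server entry point. A quick way to exercise the affected `/v1/...` routes is the official `openai` Python client pointed at a locally running sglang server. This is a minimal sketch, not taken from the package: it assumes the server was launched separately (for example via `sglang.launch_server`) and is listening on port 30000; adjust the address and model name to your setup.

```python
# Minimal sketch: call the OpenAI-compatible routes served by sglang.
# Assumes a server is already running locally on port 30000 (hypothetical
# address); the `model` field is simply echoed back by the adapter.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "The capital city of France is"}],
    temperature=0,
    max_tokens=8,
)
print(resp.choices[0].message.content)
```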
sglang/srt/openai_api_adapter.py
CHANGED
@@ -1,9 +1,12 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
+
+import asyncio
 import json
 import os
+from http import HTTPStatus

-from fastapi import
-from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse

 from sglang.srt.conversation import (
     Conversation,
@@ -26,14 +29,33 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model
-

 chat_template_name = None

+
+def create_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
+
+
+def create_streaming_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name

@@ -73,8 +95,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)

-
-
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     adapted_request = GenerateReqInput(
         text=request.prompt,
@@ -92,79 +114,95 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            [19 removed lines; content not shown in this view]
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer:  # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"][
+                                "decode_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["decode_token_logprobs"]
+                        )
                     else:
-
-                        prefill_top_logprobs = None
-
-                        logprobs = to_openai_style_logprobs(
-                            prefill_token_logprobs=prefill_token_logprobs,
-                            prefill_top_logprobs=prefill_top_logprobs,
-                            decode_token_logprobs=content["meta_info"][
-                                "decode_token_logprobs"
-                            ][n_prev_token:],
-                            decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                                n_prev_token:
-                            ],
-                        )
+                        logprobs = None

-                        [23 removed lines; content not shown in this view]
-                        yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = content["text"]
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
-
-
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))

+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
@@ -192,7 +230,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         index=0,
         text=text,
         logprobs=logprobs,
-        finish_reason=
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
@@ -211,8 +249,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)

-
-
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     # Prep the data needed for the underlying GenerateReqInput:
     # - prompt: The full prompt string.
@@ -257,7 +295,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

@@ -265,46 +302,64 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
            is_first = True

            stream_buffer = ""
-           [4 removed lines; content not shown in this view]
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
-                        delta=DeltaMessage(
-                        finish_reason=
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
                     )
-                    yield f"data: {
-                   [3 removed lines; content not shown in this view]
-                    stream_buffer = text
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                    )
-                    chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        choices=[choice_data],
-                        model=request.model,
-                    )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
-
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = ChatCompletionResponse(
         id=ret["meta_info"]["id"],
@@ -332,7 +387,7 @@ def to_openai_style_logprobs(
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(logprob)

-        # Not
+        # Not supported yet
         ret_logprobs.text_offset.append(-1)

     def append_top_logprobs(top_logprobs):
@@ -353,4 +408,4 @@ def to_openai_style_logprobs(
     if decode_top_logprobs is not None:
         append_top_logprobs(decode_top_logprobs)

-    return ret_logprobs
+    return ret_logprobs
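Two behavioral notes fall out of this file: responses are now serialized with Pydantic's `model_dump_json()` instead of the removed `jsonify_pydantic_model` helper, and errors raised while streaming are caught and emitted as a final `data:` chunk of the form `{"error": {...}}` (built by `create_streaming_error_response`) before `data: [DONE]`. A client consuming the stream therefore has to check each payload for that key. The sketch below is illustrative only; it assumes a local server on port 30000 and a placeholder model name.

```python
import json
import requests

# Illustrative client for the streaming /v1/completions route.
# Assumes a local sglang server on port 30000 (hypothetical address).
with requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "default",  # echoed back by the adapter
        "prompt": "The capital city of France is",
        "max_tokens": 8,
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if "error" in chunk:
            # Emitted via create_streaming_error_response on ValueError.
            raise RuntimeError(chunk["error"]["message"])
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
```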
sglang/srt/openai_protocol.py
CHANGED
@@ -1,4 +1,5 @@
-"""
+"""Pydantic models for OpenAI API protocol"""
+
 import time
 from typing import Dict, List, Optional, Union

@@ -6,6 +7,14 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal


+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -178,4 +187,4 @@ class ChatCompletionStreamResponse(BaseModel):
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseStreamChoice]
+    choices: List[ChatCompletionResponseStreamChoice]
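The new `ErrorResponse` model is what both error helpers in `openai_api_adapter.py` serialize. A small standalone sketch of the shape it produces (the field names and defaults mirror the diff above; the surrounding script is illustrative):

```python
from http import HTTPStatus
from typing import Optional

from pydantic import BaseModel


class ErrorResponse(BaseModel):
    # Mirrors the model added in sglang/srt/openai_protocol.py (0.1.18).
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


err = ErrorResponse(
    message="n != 1 is not supported",
    type="BadRequestError",
    code=HTTPStatus.BAD_REQUEST.value,
)
print(err.model_dump())
# -> {'object': 'error', 'message': 'n != 1 is not supported',
#     'type': 'BadRequestError', 'param': None, 'code': 400}
```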
sglang/srt/server.py
CHANGED
@@ -1,4 +1,7 @@
-"""
+"""
+The entry point of inference server.
+SRT = SGLang Runtime.
+"""

 import asyncio
 import dataclasses
@@ -9,7 +12,8 @@ import os
 import sys
 import threading
 import time
-from
+from http import HTTPStatus
+from typing import Dict, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -25,21 +29,36 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.manager_multi import (
+    start_controller_process as start_controller_process_multi,
+)
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
+)
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
-
-
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-   [3 removed lines; content not shown in this view]
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
+from sglang.utils import get_exception_traceback
+
+
+logger = logging.getLogger(__name__)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -69,7 +88,7 @@ async def get_server_args():

 @app.get("/flush_cache")
 async def flush_cache():
-
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -77,24 +96,35 @@ async def flush_cache():
     )


-
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:

         async def stream_results():
-
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
+

-
-
-            return ret
-        except ValueError as e:
-            return JSONResponse({"error": str(e)}, status_code=400)
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)


 @app.post("/v1/completions")
@@ -121,31 +151,66 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     enable_show_time_cost()
     if server_args.disable_disk_cache:
         disable_cache()
-    if server_args.
-        assert_pkg_version("flashinfer", "0.0.
+    if not server_args.disable_flashinfer:
+        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
+                           "reinstall the latest version by following the instructions "
+                           "at https://docs.flashinfer.ai/installation.html.")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
+        server_args.port,
+        server_args.additional_ports,
+        tp_size_local,
+        server_args.dp_size,
     )
+
+    ports = server_args.additional_ports
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=
-        router_port=
-        detokenizer_port=
-
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=
+        target=start_process,
         args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
@@ -179,6 +244,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -190,43 +256,46 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             time.sleep(0.5)
             try:
                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                success = True  # Set flag to True if request succeeds
                 break
-            except requests.exceptions.RequestException
+            except requests.exceptions.RequestException:
                 pass

         # Send a warmup request
         try:
-           [7 removed lines; content not shown in this view]
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 8,
+                        },
                     },
-                   [4 removed lines; content not shown in this view]
-            assert res.status_code == 200
+                    headers=headers,
+                    timeout=600,
+                )
+                assert res.status_code == 200
         except Exception as e:
             if pipe_finish_writer is not None:
                 pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}")
+            print(f"Initialization failed. warmup error: {e}", flush=True)
             raise e

+        logger.info("The server is fired up and ready to roll!")
         if pipe_finish_writer is not None:
             pipe_finish_writer.send("init ok")

     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
             host=server_args.host,
             port=server_args.port,
-            log_level=server_args.log_level,
+            log_level=server_args.log_level_http or server_args.log_level,
             timeout_keep_alive=5,
             loop="uvloop",
         )
@@ -235,21 +304,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg


 class Runtime:
+    """
+    A wrapper for the server.
+    This is used for launching the server in a python program without
+    using the commond line interface.
+    """
+
     def __init__(
         self,
-
+        log_level: str = "error",
         model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
             self.server_args.tp_size,
+            self.server_args.dp_size,
         )

         self.url = self.server_args.url()
@@ -304,7 +380,7 @@ class Runtime:
     async def add_request(
         self,
         prompt: str,
-        sampling_params,
+        sampling_params: Dict,
     ):
         json_data = {
             "text": prompt,
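The warmup logic added to `launch_server` doubles as a usage example for the native `/generate` endpoint, which this version registers for both POST and PUT. A hedged client-side sketch of the same call (the address is assumed; the payload and the `text`/`meta_info` response fields come from the diff above):

```python
import requests

# Mirrors the warmup request launch_server sends to /generate in 0.1.18.
# Assumes a server running locally on port 30000 (hypothetical address).
res = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
    timeout=600,
)
res.raise_for_status()
out = res.json()
print(out["text"])
print(out["meta_info"]["prompt_tokens"], out["meta_info"]["completion_tokens"])
```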