sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
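For orientation before the server.py diff below, here is a hedged client sketch of the OpenAI-compatible routes that the rewritten server now delegates to sglang/srt/openai_api_adapter.py. It is not part of the package; the host, port, and model name are assumptions, not values taken from this diff.

# Hedged sketch (assumed local server address and placeholder model name):
# exercises the /v1/chat/completions route registered in the new server.py.
import requests

BASE_URL = "http://127.0.0.1:30000"  # assumption: a locally running sglang server

chat = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "default",  # placeholder; sglang serves the model it was launched with
        "messages": [{"role": "user", "content": "The capital city of France is"}],
        "temperature": 0,
        "max_tokens": 8,
    },
    timeout=600,
)
print(chat.json()["choices"][0]["message"]["content"])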
sglang/srt/server.py
CHANGED
@@ -1,97 +1,68 @@
-"""
+"""
+The entry point of inference server.
+SRT = SGLang Runtime.
+"""
 
 import asyncio
 import dataclasses
 import json
+import logging
 import multiprocessing as mp
 import os
 import sys
 import threading
 import time
-from
+from http import HTTPStatus
+from typing import Dict, Optional
 
-# Fix a Python
+# Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
 import psutil
-import pydantic
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI,
-from fastapi.responses import Response, StreamingResponse
-
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
-from sglang.srt.conversation import (
-    Conversation,
-    SeparatorStyle,
-    chat_template_exists,
-    generate_chat_conv,
-    register_conv_template,
-)
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.
-
-
-
-
-
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    DeltaMessage,
-    LogProbs,
-    UsageInfo,
+from sglang.srt.managers.controller.manager_multi import (
+    start_controller_process as start_controller_process_multi,
+)
+from sglang.srt.managers.controller.manager_single import (
+    launch_tp_servers,
+    start_controller_process as start_controller_process_single,
 )
-from sglang.srt.managers.
+from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
+from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.
-
-
-
-
-
-
-    API_KEY_HEADER_NAME
-
+from sglang.srt.openai_api_adapter import (
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
+    allocate_init_ports,
+    assert_pkg_version,
+    enable_show_time_cost,
+    receive_addrs,
+    send_addrs_to_rank_0,
+)
+from sglang.utils import get_exception_traceback
 
-
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
+logger = logging.getLogger(__name__)
 
-
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
 app = FastAPI()
 tokenizer_manager = None
-chat_template_name = None
-
-
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
 
 
 @app.get("/health")
@@ -115,7 +86,7 @@ async def get_server_args():
 
 @app.get("/flush_cache")
 async def flush_cache():
-
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -123,361 +94,129 @@ async def flush_cache():
     )
 
 
-async def
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
-
-
-async def stream_generator(obj: GenerateReqInput):
-    async for out in tokenizer_manager.generate_request(obj):
-        if obj.return_logprob and obj.return_text_in_logprobs:
-            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-                out["meta_info"]["token_logprob"]
-            )
-        yield out
-
-
-async def make_openai_style_logprobs(token_logprobs):
-    ret_logprobs = LogProbs()
-
-    for token_text, token_logprob in token_logprobs:
-        ret_logprobs.tokens.append(token_text)
-        ret_logprobs.token_logprobs.append(token_logprob)
-
-        # Not supported yet.
-        ret_logprobs.top_logprobs.append({})
-        ret_logprobs.text_offset.append(-1)
-    return ret_logprobs
-
-
-@app.post("/generate")
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
 
         async def stream_results():
-
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(
-
-
-
-            ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-                ret["meta_info"]["token_logprob"]
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
         )
-
-        return ret
-
-
-@app.post("/v1/completions")
-async def v1_completions(raw_request: Request):
-    request_json = await raw_request.json()
-    request = CompletionRequest(**request_json)
-
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
-
-    adapted_request = GenerateReqInput(
-        text=request.prompt,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": request.stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-        },
-        return_logprob=request.logprobs is not None,
-        return_text_in_logprobs=True,
-        stream=request.stream,
-    )
-    adapted_request.post_init()
-
-    if adapted_request.stream:
-
-        async def gnerate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
-            async for content in stream_generator(adapted_request):
-                text = content["text"]
-                prompt_tokens = content["meta_info"]["prompt_tokens"]
-                completion_tokens = content["meta_info"]["completion_tokens"]
-
-                if not stream_buffer:  # The first chunk
-                    if request.echo:
-                        # Prepend prompt in response text.
-                        text = request.prompt + text
-                    else:
-                        # Skip prompt tokens if echo is disabled.
-                        n_prev_token = prompt_tokens
-
-                if request.logprobs is not None:
-                    logprobs = await make_openai_style_logprobs(
-                        content["meta_info"]["token_logprob"][n_prev_token:]
-                    )
-                    n_prev_token = len(content["meta_info"]["token_logprob"])
-                else:
-                    logprobs = None
-
-                delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
-                choice_data = CompletionResponseStreamChoice(
-                    index=0,
-                    text=delta,
-                    logprobs=logprobs,
-                    finish_reason=None,
-                )
-                chunk = CompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    object="text_completion",
-                    choices=[choice_data],
-                    model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-            yield "data: [DONE]\n\n"
-
-        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
-
-    # Non-streaming response.
-    ret = await generate_request(adapted_request)
-    ret = ret[0] if isinstance(ret, list) else ret
-
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
-    text = ret["text"]
-    token_logprob_pos = prompt_tokens
-    if request.echo:
-        token_logprob_pos = 0
-        text = request.prompt + text
     else:
-
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
 
-    logprobs = (
-        await make_openai_style_logprobs(
-            ret["meta_info"]["token_logprob"][token_logprob_pos:]
-        )
-        if request.logprobs is not None
-        else None
-    )
-    choice_data = CompletionResponseChoice(
-        index=0,
-        text=text,
-        logprobs=logprobs,
-        finish_reason=None,  # TODO(comaniac): Add finish reason.
-    )
 
-
-
-        model=request.model,
-        choices=[choice_data],
-        usage=UsageInfo(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        ),
-    )
-    return response
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)
 
 
-@app.post("/v1/
-async def
-
-    request = ChatCompletionRequest(**request_json)
-
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
-
-    # Prep the data needed for the underlying GenerateReqInput:
-    # - prompt: The full prompt string.
-    # - stop: Custom stop tokens.
-    # - image_data: None or a list of image strings (URLs or base64 strings).
-    #   None skips any image processing in GenerateReqInput.
-    if not isinstance(request.messages, str):
-        # Apply chat template and its stop strings.
-        if chat_template_name is None:
-            # This flow doesn't support the full OpenAI spec. Verify messages
-            # has the right type before proceeding:
-            for m in request.messages:
-                if not isinstance(m.content, str):
-                    raise HTTPException(
-                        status_code=503,
-                        detail="Structured content requests not supported with "
-                        "HuggingFace Chat Templates. "
-                        "Make sure the server specifies a sglang chat template.",
-                    )
-            prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                request.messages, tokenize=False, add_generation_prompt=True
-            )
-            stop = request.stop
-            image_data = None
-        else:
-            conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
-            image_data = conv.image_data
-            stop = conv.stop_str or []
-            if request.stop:
-                if isinstance(request.stop, str):
-                    stop.append(request.stop)
-                else:
-                    stop.extend(request.stop)
-    else:
-        # Use the raw prompt and stop strings if the messages is already a string.
-        prompt = request.messages
-        stop = request.stop
-        image_data = None
-
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-        },
-        stream=request.stream,
-    )
-    adapted_request.post_init()
-
-    if adapted_request.stream:
-
-        async def gnerate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            async for content in stream_generator(adapted_request):
-                if is_first:
-                    # First chunk with role
-                    is_first = False
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
-                        delta=DeltaMessage(role="assistant"),
-                        finish_reason=None,
-                    )
-                    chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        choices=[choice_data],
-                        model=request.model,
-                    )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
-                text = content["text"]
-                delta = text[len(stream_buffer) :]
-                stream_buffer = text
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-            yield "data: [DONE]\n\n"
+@app.post("/v1/completions")
+async def openai_v1_completions(raw_request: Request):
+    return await v1_completions(tokenizer_manager, raw_request)
 
-        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
 
-
-
-
-    completion_tokens = ret["meta_info"]["completion_tokens"]
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=None,  # TODO(comaniac): Add finish reason.
-    )
-    response = ChatCompletionResponse(
-        id=ret["meta_info"]["id"],
-        model=request.model,
-        choices=[choice_data],
-        usage=UsageInfo(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        ),
-    )
-    return response
+@app.post("/v1/chat/completions")
+async def openai_v1_chat_completions(raw_request: Request):
+    return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-def launch_server(server_args, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager
-    global chat_template_name
 
-
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    if server_args.show_time_cost:
+        enable_show_time_cost()
     if server_args.disable_disk_cache:
         disable_cache()
-
-
-
-
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.0.8",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
+    server_args.port, server_args.additional_ports = allocate_init_ports(
+        server_args.port,
+        server_args.additional_ports,
+        tp_size_local,
+        server_args.dp_size,
     )
 
+    ports = server_args.additional_ports
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[
+                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+                ],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=
-        router_port=
-        detokenizer_port=
-
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
    )
 
-    #
-    if server_args.
-
-
-
-
-
-
-
-
-
-
-
-            except KeyError:
-                raise ValueError(
-                    f"Unknown separator style: {template['sep_style']}"
-                ) from None
-            register_conv_template(
-                Conversation(
-                    name=template["name"],
-                    system_template=template["system"] + "\n{system_message}",
-                    system_message=template.get("system_message", ""),
-                    roles=(template["user"], template["assistant"]),
-                    sep_style=sep_style,
-                    sep=template.get("sep", "\n"),
-                    stop_str=template["stop_str"],
-                ),
-                override=True,
-            )
-            chat_template_name = template["name"]
-        else:
-            chat_template_name = server_args.chat_template
+    # Handle multi-node tp
+    if server_args.nnodes > 1:
+        assert server_args.dp_size == 1, "Multi-node dp is not supported."
+
+        if server_args.node_rank != 0:
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+            tp_rank_range = list(range(server_args.node_rank * tp_size_local,
+                                       (server_args.node_rank + 1) * tp_size_local))
+            procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                                      port_args.model_port_args[0], model_overide_args)
+            while True:
+                pass
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        target=start_process,
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()
    proc_detoken = mp.Process(
@@ -497,128 +236,101 @@ def launch_server(server_args, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print(
-
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
         sys.exit(1)
-
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
-
-    uvicorn.run(
-        app,
-        host=server_args.host,
-        port=server_args.port,
-        log_level=server_args.log_level,
-        timeout_keep_alive=5,
-        loop="uvloop",
-    )
-
+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
-        if server_args.api_key
+        if server_args.api_key:
             headers[API_KEY_HEADER_NAME] = server_args.api_key
 
+        # Wait until the server is launched
         for _ in range(120):
             time.sleep(0.5)
             try:
                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
                 break
-            except requests.exceptions.RequestException
+            except requests.exceptions.RequestException:
                 pass
-        else:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
 
-        #
+        # Send a warmup request
         try:
-
-
-
-
-
-
-
-
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 8,
+                        },
                     },
-
-
-
-
-
-            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
-        except requests.exceptions.RequestException as e:
+                    headers=headers,
+                    timeout=600,
+                )
+            assert res.status_code == 200
+        except Exception as e:
             if pipe_finish_writer is not None:
-                pipe_finish_writer.send(
-
-
-            return
+                pipe_finish_writer.send(get_exception_traceback())
+            print(f"Initialization failed. warmup error: {e}", flush=True)
+            raise e
 
+        logger.info("The server is fired up and ready to roll!")
         if pipe_finish_writer is not None:
             pipe_finish_writer.send("init ok")
 
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
-
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level_http or server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
     finally:
         t.join()
 
 
 class Runtime:
+    """
+    A wrapper for the server.
+    This is used for launching the server in a python program without
+    using the commond line interface.
+    """
+
     def __init__(
         self,
-        model_path: str,
-        tokenizer_path: Optional[str] = None,
-        load_format: str = "auto",
-        tokenizer_mode: str = "auto",
-        trust_remote_code: bool = True,
-        mem_fraction_static: float = ServerArgs.mem_fraction_static,
-        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
-        context_length: int = ServerArgs.context_length,
-        tp_size: int = 1,
-        schedule_heuristic: str = "lpm",
-        attention_reduce_in_fp32: bool = False,
-        random_seed: int = 42,
         log_level: str = "error",
-
-
-
-        disable_disk_cache: bool = False,
-        api_key: str = "",
-        port: Optional[int] = None,
-        additional_ports: Optional[Union[List[int], int]] = None,
+        model_overide_args: Optional[dict] = None,
+        *args,
+        **kwargs,
    ):
-
-
-
-
-
-
-
-
-
-            tokenizer_mode=tokenizer_mode,
-            trust_remote_code=trust_remote_code,
-            mem_fraction_static=mem_fraction_static,
-            max_prefill_num_token=max_prefill_num_token,
-            context_length=context_length,
-            tp_size=tp_size,
-            schedule_heuristic=schedule_heuristic,
-            attention_reduce_in_fp32=attention_reduce_in_fp32,
-            random_seed=random_seed,
-            log_level=log_level,
-            disable_radix_cache=disable_radix_cache,
-            enable_flashinfer=enable_flashinfer,
-            disable_regex_jump_forward=disable_regex_jump_forward,
-            disable_disk_cache=disable_disk_cache,
-            api_key=api_key,
+        """See the arguments in server_args.py::ServerArgs"""
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+
+        # Pre-allocate ports
+        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
+            self.server_args.port,
+            self.server_args.additional_ports,
+            self.server_args.tp_size,
+            self.server_args.dp_size,
        )
 
        self.url = self.server_args.url()
@@ -628,7 +340,10 @@ class Runtime:
 
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
         proc.start()
         pipe_writer.close()
         self.pid = proc.pid
@@ -640,7 +355,9 @@ class Runtime:
 
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError(
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
 
         self.endpoint = RuntimeEndpoint(self.url)
 
@@ -668,14 +385,13 @@ class Runtime:
     async def add_request(
         self,
         prompt: str,
-        sampling_params,
-    )
+        sampling_params: Dict,
+    ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "stream": True,
         }
-
         pos = 0
 
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
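For completeness, a hedged usage sketch of the new in-process Runtime wrapper shown above. Keyword arguments are forwarded to ServerArgs, so model_path is assumed from that class rather than from this hunk, and the checkpoint name is a placeholder; the /generate payload mirrors the warmup request in launch_server().

# Hedged sketch (not from the diff): launch the server in-process via Runtime,
# then replay the same /generate warmup request used by launch_server() above.
import requests
from sglang.srt.server import Runtime

# model_path is forwarded to ServerArgs; the checkpoint name is a placeholder.
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf", log_level="error")
try:
    res = requests.post(
        runtime.url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        timeout=600,
    )
    print(res.json()["text"])
finally:
    runtime.shutdown()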
|