ipex-llm 2.3.0b20250427__py3-none-win_amd64.whl → 2.3.0b20250501__py3-none-win_amd64.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- ipex_llm/libs/bloom-api.dll +0 -0
- ipex_llm/libs/bloom.dll +0 -0
- ipex_llm/libs/gptneox-api.dll +0 -0
- ipex_llm/libs/gptneox.dll +0 -0
- ipex_llm/libs/libbloom_avx.dll +0 -0
- ipex_llm/libs/libbloom_vnni.dll +0 -0
- ipex_llm/libs/libgptneox_avx.dll +0 -0
- ipex_llm/libs/libgptneox_vnni.dll +0 -0
- ipex_llm/libs/libllama_avx.dll +0 -0
- ipex_llm/libs/libllama_vnni.dll +0 -0
- ipex_llm/libs/libstarcoder_avx.dll +0 -0
- ipex_llm/libs/libstarcoder_vnni.dll +0 -0
- ipex_llm/libs/llama-api.dll +0 -0
- ipex_llm/libs/llama.dll +0 -0
- ipex_llm/libs/main-bloom.exe +0 -0
- ipex_llm/libs/main-gptneox.exe +0 -0
- ipex_llm/libs/main-llama.exe +0 -0
- ipex_llm/libs/main-starcoder.exe +0 -0
- ipex_llm/libs/pipeline.dll +0 -0
- ipex_llm/libs/quantize-bloom.exe +0 -0
- ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
- ipex_llm/libs/quantize-gptneox.exe +0 -0
- ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
- ipex_llm/libs/quantize-llama.exe +0 -0
- ipex_llm/libs/quantize-llama_vnni.exe +0 -0
- ipex_llm/libs/quantize-starcoder.exe +0 -0
- ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
- ipex_llm/libs/starcoder-api.dll +0 -0
- ipex_llm/libs/starcoder.dll +0 -0
- ipex_llm/transformers/convert.py +3 -2
- ipex_llm/vllm/xpu/engine/__init__.py +3 -1
- ipex_llm/vllm/xpu/engine/engine.py +163 -19
- ipex_llm/vllm/xpu/entrypoints/openai/api_server.py +448 -180
- ipex_llm/vllm/xpu/model_convert.py +5 -2
- {ipex_llm-2.3.0b20250427.dist-info → ipex_llm-2.3.0b20250501.dist-info}/METADATA +11 -11
- {ipex_llm-2.3.0b20250427.dist-info → ipex_llm-2.3.0b20250501.dist-info}/RECORD +42 -42
- {ipex_llm-2.3.0b20250427.data → ipex_llm-2.3.0b20250501.data}/scripts/ipex-llm-init.bat +0 -0
- {ipex_llm-2.3.0b20250427.data → ipex_llm-2.3.0b20250501.data}/scripts/llm-chat.ps1 +0 -0
- {ipex_llm-2.3.0b20250427.data → ipex_llm-2.3.0b20250501.data}/scripts/llm-cli.ps1 +0 -0
- {ipex_llm-2.3.0b20250427.dist-info → ipex_llm-2.3.0b20250501.dist-info}/WHEEL +0 -0
- {ipex_llm-2.3.0b20250427.dist-info → ipex_llm-2.3.0b20250501.dist-info}/entry_points.txt +0 -0
- {ipex_llm-2.3.0b20250427.dist-info → ipex_llm-2.3.0b20250501.dist-info}/top_level.txt +0 -0
ipex_llm/vllm/xpu/entrypoints/openai/api_server.py

@@ -1,5 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -10,16 +13,18 @@ import socket
 import tempfile
 import uuid
 from argparse import Namespace
+from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import
+from typing import Annotated, Optional, Union

 import uvloop
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import State
 from starlette.routing import Mount
 from typing_extensions import assert_never
@@ -27,17 +32,17 @@ from typing_extensions import assert_never
 import vllm.envs as envs
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
-from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
+from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine # type: ignore
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from ipex_llm.vllm.xpu.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import (load_chat_template,
+                                         resolve_hf_chat_template,
+                                         resolve_mistral_chat_template)
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-
-# from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -46,33 +51,46 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionResponse,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
+                                              EmbeddingChatRequest,
+                                              EmbeddingCompletionRequest,
                                               EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse,
-
+                                              LoadLoRAAdapterRequest,
+                                              PoolingChatRequest,
+                                              PoolingCompletionRequest,
                                               PoolingRequest, PoolingResponse,
+                                              RerankRequest, RerankResponse,
                                               ScoreRequest, ScoreResponse,
                                               TokenizeRequest,
                                               TokenizeResponse,
-
+                                              TranscriptionRequest,
+                                              TranscriptionResponse,
+                                              UnloadLoRAAdapterRequest)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
-
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
-from vllm.entrypoints.openai.serving_score import
+from vllm.entrypoints.openai.serving_score import ServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
+from vllm.entrypoints.openai.serving_transcription import (
+    OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import
+from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
+                                    with_cancellation)
 from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
+from vllm.transformers_utils.config import (
+    maybe_register_config_serialize_by_value)
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
 from vllm.version import __version__ as VLLM_VERSION

@@ -83,7 +101,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
 logger = init_logger('vllm.entrypoints.openai.api_server')

-_running_tasks:
+_running_tasks: set[asyncio.Task] = set()


 @asynccontextmanager
@@ -102,6 +120,11 @@ async def lifespan(app: FastAPI):
             task.add_done_callback(_running_tasks.remove)
         else:
             task = None
+
+        # Mark the startup heap as static so that it's ignored by GC.
+        # Reduces pause times of oldest generation collections.
+        gc.collect()
+        gc.freeze()
         try:
             yield
         finally:
@@ -139,24 +162,49 @@ async def build_async_engine_client_from_engine_args(
     Returns the Client or None if the creation failed.
     """

-    #
-
-
-
+    # Create the EngineConfig (determines if we can use V1).
+    usage_context = UsageContext.OPENAI_API_SERVER
+    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+
+    # V1 AsyncLLM.
+    if envs.VLLM_USE_V1:
+        if disable_frontend_multiprocessing:
+            logger.warning(
+                "V1 is enabled, but got --disable-frontend-multiprocessing. "
+                "To disable frontend multiprocessing, set VLLM_USE_V1=0.")
+
+        from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine as AsyncLLM
+        async_llm: Optional[AsyncLLM] = None
+        try:
+            async_llm = AsyncLLM.from_vllm_config(
+                vllm_config=vllm_config,
+                usage_context=usage_context,
+                disable_log_requests=engine_args.disable_log_requests,
+                disable_log_stats=engine_args.disable_log_stats,
+                load_in_low_bit=load_in_low_bit)
+            yield async_llm
+        finally:
+            if async_llm:
+                async_llm.shutdown()
+
+    # V0 AsyncLLM.
+    elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
+          or disable_frontend_multiprocessing):
+
         engine_client: Optional[EngineClient] = None
         try:
-
-
-
-
-
+            engine_client = AsyncLLMEngine.from_vllm_config(
+                vllm_config=vllm_config,
+                usage_context=usage_context,
+                disable_log_requests=engine_args.disable_log_requests,
+                disable_log_stats=engine_args.disable_log_stats,
                 load_in_low_bit=load_in_low_bit)
             yield engine_client
         finally:
             if engine_client and hasattr(engine_client, "shutdown"):
                 engine_client.shutdown()

-    #
+    # V0MQLLMEngine.
     else:
         if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
             # Make TemporaryDirectory for prometheus multiprocessing
@@ -183,14 +231,18 @@ async def build_async_engine_client_from_engine_args(
         # so we need to spawn a new process
         context = multiprocessing.get_context("spawn")

+        # Ensure we can serialize transformer config before spawning
+        maybe_register_config_serialize_by_value()
+
         # The Process can raise an exception during startup, which may
         # not actually result in an exitcode being reported. As a result
         # we use a shared variable to communicate the information.
         engine_alive = multiprocessing.Value('b', True, lock=False)
-        engine_process = context.Process(
-
-
-
+        engine_process = context.Process(
+            target=run_mp_engine,
+            args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
+                  engine_args.disable_log_stats,
+                  engine_args.disable_log_requests, load_in_low_bit, engine_alive))
         engine_process.start()
         engine_pid = engine_process.pid
         assert engine_pid is not None, "Engine process failed to start."
@@ -205,8 +257,7 @@ async def build_async_engine_client_from_engine_args(
         atexit.register(_cleanup_ipc_path)

         # Build RPCClient, which conforms to EngineClient Protocol.
-
-        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
+        build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
                                engine_pid)
         mq_engine_client = await asyncio.get_running_loop().run_in_executor(
             None, build_client)
@@ -244,6 +295,43 @@ async def build_async_engine_client_from_engine_args(
             multiprocess.mark_process_dead(engine_process.pid)


+async def validate_json_request(raw_request: Request):
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise HTTPException(
+            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            detail="Unsupported Media Type: Only 'application/json' is allowed"
+        )
+
+
+save_dict = {}
+import os
+flag = os.getenv("VLLM_LOG_OUTPUT", None)
+async def stream_generator(generator, request, request_id):
+    async for chunk in generator:
+        if request_id not in save_dict:
+            save_dict[request_id] = ""
+        import json
+        try:
+            data = chunk.strip()
+            if data.startswith('data: '):
+                data = data[len('data: '):]
+            else:
+                yield chunk
+            json_data = json.loads(data)
+            if 'choices' in json_data and len(json_data['choices']) > 0:
+                choice = json_data['choices'][0]
+                if 'delta' in choice:
+                    save_dict[request_id] += choice["delta"]["content"]
+                elif 'text' in choice:
+                    save_dict[request_id] += choice["text"]
+        except json.JSONDecodeError:
+            print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
+            pass # Done
+        yield chunk
+
+
 router = APIRouter()


@@ -254,6 +342,7 @@ def mount_metrics(app: FastAPI):
     # See https://prometheus.github.io/client_python/multiprocess/
     from prometheus_client import (CollectorRegistry, make_asgi_app,
                                    multiprocess)
+    from prometheus_fastapi_instrumentator import Instrumentator

     prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
     if prometheus_multiproc_dir_path is not None:
@@ -261,6 +350,16 @@ def mount_metrics(app: FastAPI):
                      prometheus_multiproc_dir_path)
         registry = CollectorRegistry()
         multiprocess.MultiProcessCollector(registry)
+        Instrumentator(
+            excluded_handlers=[
+                "/metrics",
+                "/health",
+                "/load",
+                "/ping",
+                "/version",
+            ],
+            registry=registry,
+        ).add().instrument(app).expose(app)

     # Add prometheus asgi middleware to route /metrics requests
     metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
@@ -298,7 +397,11 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
     return request.app.state.openai_serving_embedding


-def score(request: Request) -> Optional[
+def score(request: Request) -> Optional[ServingScores]:
+    return request.app.state.openai_serving_scores
+
+
+def rerank(request: Request) -> Optional[ServingScores]:
     return request.app.state.openai_serving_scores


@@ -306,6 +409,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization


+def transcription(request: Request) -> OpenAIServingTranscription:
+    return request.app.state.openai_serving_transcription
+
+
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client

@@ -317,7 +424,31 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)


-@router.
+@router.get("/load")
+async def get_server_load_metrics(request: Request):
+    # This endpoint returns the current server load metrics.
+    # It tracks requests utilizing the GPU from the following routes:
+    # - /v1/chat/completions
+    # - /v1/completions
+    # - /v1/audio/transcriptions
+    # - /v1/embeddings
+    # - /pooling
+    # - /score
+    # - /v1/score
+    # - /rerank
+    # - /v1/rerank
+    # - /v2/rerank
+    return JSONResponse(
+        content={'server_load': request.app.state.server_load_metrics})
+
+
+@router.api_route("/ping", methods=["GET", "POST"])
+async def ping(raw_request: Request) -> Response:
+    """Ping check. Endpoint required for SageMaker"""
+    return await health(raw_request)
+
+
+@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
@@ -332,7 +463,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/detokenize")
+@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
@@ -361,35 +492,10 @@ async def show_version():
     return JSONResponse(content=ver)


-
-
-flag = os.getenv("VLLM_LOG_OUTPUT", None)
-async def stream_generator(generator, request, request_id):
-    async for chunk in generator:
-        if request_id not in save_dict:
-            save_dict[request_id] = ""
-        import json
-        try:
-            data = chunk.strip()
-            if data.startswith('data: '):
-                data = data[len('data: '):]
-            else:
-                yield chunk
-            json_data = json.loads(data)
-            if 'choices' in json_data and len(json_data['choices']) > 0:
-                choice = json_data['choices'][0]
-                if 'delta' in choice:
-                    save_dict[request_id] += choice["delta"]["content"]
-                elif 'text' in choice:
-                    save_dict[request_id] += choice["text"]
-        except json.JSONDecodeError:
-            print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
-            pass # Done
-        yield chunk
-
-
-@router.post("/v1/chat/completions")
+@router.post("/v1/chat/completions",
+             dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -401,7 +507,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = "chatcmpl-" \
         f"{handler._base_request_id(raw_request, request.request_id)}"
     print(f"First received request_id: {request_id}, request: {request}")
-
     generator = await handler.create_chat_completion(request, raw_request)

     if isinstance(generator, ErrorResponse):
@@ -418,8 +523,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post("/v1/completions")
+@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -438,14 +544,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         if flag is not None:
             print(f"Received request-id:{request_id}, request:{request}, Output:{generator.model_dump()}")
         return JSONResponse(content=generator.model_dump())
-
+
     if flag is not None:
         return StreamingResponse(content=stream_generator(generator, request, request_id), media_type="text/event-stream")
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post("/v1/embeddings")
+@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -460,6 +567,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
             "use the Pooling API (`/pooling`) instead.")

         res = await fallback_handler.create_pooling(request, raw_request)
+
+        generator: Union[ErrorResponse, EmbeddingResponse]
         if isinstance(res, PoolingResponse):
             generator = EmbeddingResponse(
                 id=res.id,
@@ -488,8 +597,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/pooling")
+@router.post("/pooling", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
     handler = pooling(raw_request)
     if handler is None:
@@ -506,8 +616,9 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/score")
+@router.post("/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -524,8 +635,9 @@ async def create_score(request: ScoreRequest, raw_request: Request):
     assert_never(generator)


-@router.post("/v1/score")
+@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
@@ -534,6 +646,160 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)


+@router.post("/v1/audio/transcriptions")
+@with_cancellation
+@load_aware_call
+async def create_transcriptions(request: Annotated[TranscriptionRequest,
+                                                   Form()],
+                                raw_request: Request):
+    handler = transcription(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Transcriptions API")
+
+    audio_data = await request.file.read()
+    generator = await handler.create_transcription(audio_data, request,
+                                                   raw_request)
+
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+
+    elif isinstance(generator, TranscriptionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+@router.post("/rerank", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+@load_aware_call
+async def do_rerank(request: RerankRequest, raw_request: Request):
+    handler = rerank(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Rerank (Score) API")
+    generator = await handler.do_rerank(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, RerankResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+async def do_rerank_v1(request: RerankRequest, raw_request: Request):
+    logger.warning_once(
+        "To indicate that the rerank API is not part of the standard OpenAI"
+        " API, we have located it at `/rerank`. Please update your client "
+        "accordingly. (Note: Conforms to JinaAI rerank API)")
+
+    return await do_rerank(request, raw_request)
+
+
+@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+async def do_rerank_v2(request: RerankRequest, raw_request: Request):
+    return await do_rerank(request, raw_request)
+
+
+TASK_HANDLERS: dict[str, dict[str, tuple]] = {
+    "generate": {
+        "messages": (ChatCompletionRequest, create_chat_completion),
+        "default": (CompletionRequest, create_completion),
+    },
+    "embed": {
+        "messages": (EmbeddingChatRequest, create_embedding),
+        "default": (EmbeddingCompletionRequest, create_embedding),
+    },
+    "score": {
+        "default": (RerankRequest, do_rerank)
+    },
+    "rerank": {
+        "default": (RerankRequest, do_rerank)
+    },
+    "reward": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+    "classify": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+}
+
+if envs.VLLM_SERVER_DEV_MODE:
+
+    @router.post("/reset_prefix_cache")
+    async def reset_prefix_cache(raw_request: Request):
+        """
+        Reset the prefix cache. Note that we currently do not check if the
+        prefix cache is successfully reset in the API server.
+        """
+        device = None
+        device_str = raw_request.query_params.get("device")
+        if device_str is not None:
+            device = Device[device_str.upper()]
+            logger.info("Resetting prefix cache with specific %s...", str(device))
+        await engine_client(raw_request).reset_prefix_cache(device)
+        return Response(status_code=200)
+
+    @router.post("/sleep")
+    async def sleep(raw_request: Request):
+        # get POST params
+        level = raw_request.query_params.get("level", "1")
+        await engine_client(raw_request).sleep(int(level))
+        # FIXME: in v0 with frontend multiprocessing, the sleep command
+        # is sent but does not finish yet when we return a response.
+        return Response(status_code=200)
+
+    @router.post("/wake_up")
+    async def wake_up(raw_request: Request):
+        tags = raw_request.query_params.getlist("tags")
+        if tags == []:
+            # set to None to wake up all tags if no tags are provided
+            tags = None
+        logger.info("wake up the engine with tags: %s", tags)
+        await engine_client(raw_request).wake_up(tags)
+        # FIXME: in v0 with frontend multiprocessing, the wake-up command
+        # is sent but does not finish yet when we return a response.
+        return Response(status_code=200)
+
+    @router.get("/is_sleeping")
+    async def is_sleeping(raw_request: Request):
+        logger.info("check whether the engine is sleeping")
+        is_sleeping = await engine_client(raw_request).is_sleeping()
+        return JSONResponse(content={"is_sleeping": is_sleeping})
+
+
+@router.post("/invocations", dependencies=[Depends(validate_json_request)])
+async def invocations(raw_request: Request):
+    """
+    For SageMaker, routes requests to other handlers based on model `task`.
+    """
+    body = await raw_request.json()
+    task = raw_request.app.state.task
+
+    if task not in TASK_HANDLERS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported task: '{task}' for '/invocations'. "
+            f"Expected one of {set(TASK_HANDLERS.keys())}")
+
+    handler_config = TASK_HANDLERS[task]
+    if "messages" in body:
+        request_model, handler = handler_config["messages"]
+    else:
+        request_model, handler = handler_config["default"]
+
+    # this is required since we lose the FastAPI automatic casting
+    request = request_model.model_validate(body)
+    return await handler(request, raw_request)
+
+
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning(
         "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -556,32 +822,30 @@ if envs.VLLM_TORCH_PROFILER_DIR:

 if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     logger.warning(
-        "
+        "LoRA dynamic loading & unloading is enabled in the API server. "
         "This should ONLY be used for local development!")

-    @router.post("/v1/load_lora_adapter"
-
+    @router.post("/v1/load_lora_adapter",
+                 dependencies=[Depends(validate_json_request)])
+    async def load_lora_adapter(request: LoadLoRAAdapterRequest,
                                 raw_request: Request):
-
-
-
-
-
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
+        handler = models(raw_request)
+        response = await handler.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)

         return Response(status_code=200, content=response)

-    @router.post("/v1/unload_lora_adapter"
-
+    @router.post("/v1/unload_lora_adapter",
+                 dependencies=[Depends(validate_json_request)])
+    async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
                                   raw_request: Request):
-
-
-
-
-
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
+        handler = models(raw_request)
+        response = await handler.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)

         return Response(status_code=200, content=response)

@@ -615,7 +879,8 @@ def build_app(args: Namespace) -> FastAPI:
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)

-
+    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
+    if token := args.api_key or envs.VLLM_API_KEY:

         @app.middleware("http")
         async def authentication(request: Request, call_next):
@@ -644,11 +909,26 @@ def build_app(args: Namespace) -> FastAPI:
             response.headers["X-Request-Id"] = request_id
             return response

+    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
+        logger.warning("CAUTION: Enabling log response in the API Server. "
+                       "This can include sensitive information and should be "
+                       "avoided in production.")
+
+        @app.middleware("http")
+        async def log_response(request: Request, call_next):
+            response = await call_next(request)
+            response_body = [
+                section async for section in response.body_iterator
+            ]
+            response.body_iterator = iterate_in_threadpool(iter(response_body))
+            logger.info("response_body={%s}", response_body[0].decode())
+            return response
+
     for middleware in args.middleware:
         module_path, object_name = middleware.rsplit(".", 1)
         imported = getattr(importlib.import_module(module_path), object_name)
         if inspect.isclass(imported):
-            app.add_middleware(imported)
+            app.add_middleware(imported) # type: ignore[arg-type]
         elif inspect.iscoroutinefunction(imported):
             app.middleware("http")(imported)
         else:
@@ -658,7 +938,7 @@ def build_app(args: Namespace) -> FastAPI:
     return app


-def init_app_state(
+async def init_app_state(
     engine_client: EngineClient,
     model_config: ModelConfig,
     state: State,
@@ -683,15 +963,36 @@ def init_app_state(
     state.log_stats = not args.disable_log_stats

     resolved_chat_template = load_chat_template(args.chat_template)
-
+    if resolved_chat_template is not None:
+        # Get the tokenizer to check official template
+        tokenizer = await engine_client.get_tokenizer()
+
+        if isinstance(tokenizer, MistralTokenizer):
+            # The warning is logged in resolve_mistral_chat_template.
+            resolved_chat_template = resolve_mistral_chat_template(
+                chat_template=resolved_chat_template)
+        else:
+            hf_chat_template = resolve_hf_chat_template(
+                tokenizer,
+                chat_template=None,
+                tools=None,
+                trust_remote_code=model_config.trust_remote_code)
+
+            if hf_chat_template != resolved_chat_template:
+                logger.warning(
+                    "Using supplied chat template: %s\n"
+                    "It is different from official chat template '%s'. "
+                    "This discrepancy may lead to performance degradation.",
+                    resolved_chat_template, args.model)

     state.openai_serving_models = OpenAIServingModels(
+        engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
     )
-
+    await state.openai_serving_models.init_static_loras()
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -703,6 +1004,8 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_reasoning=args.enable_reasoning,
+        reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.runner_type == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
@@ -728,7 +1031,13 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embed" else None
-    state.openai_serving_scores =
+    state.openai_serving_scores = ServingScores(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger) if model_config.task in (
+            "score", "embed", "pooling") else None
+    state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
@@ -742,92 +1051,26 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     )
+    state.openai_serving_transcription = OpenAIServingTranscription(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger,
+    ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task
-
-
-
-
-
-
-    # request_logger = None
-    # else:
-    # request_logger = RequestLogger(max_log_len=args.max_log_len)
-
-    # base_model_paths = [
-    # BaseModelPath(name=name, model_path=args.model)
-    # for name in served_model_names
-    # ]
-
-    # state.engine_client = engine_client
-    # state.log_stats = not args.disable_log_stats
-
-    # resolved_chat_template = load_chat_template(args.chat_template)
-    # logger.info("Using supplied chat template:\n%s", resolved_chat_template)
-
-    # state.openai_serving_chat = OpenAIServingChat(
-    # engine_client,
-    # model_config,
-    # base_model_paths,
-    # args.response_role,
-    # lora_modules=args.lora_modules,
-    # prompt_adapters=args.prompt_adapters,
-    # request_logger=request_logger,
-    # chat_template=resolved_chat_template,
-    # chat_template_content_format=args.chat_template_content_format,
-    # return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    # enable_auto_tools=args.enable_auto_tool_choice,
-    # tool_parser=args.tool_call_parser,
-    # enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    # ) if model_config.runner_type == "generate" else None
-    # state.openai_serving_completion = OpenAIServingCompletion(
-    # engine_client,
-    # model_config,
-    # base_model_paths,
-    # lora_modules=args.lora_modules,
-    # prompt_adapters=args.prompt_adapters,
-    # request_logger=request_logger,
-    # return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    # ) if model_config.runner_type == "generate" else None
-    # state.openai_serving_pooling = OpenAIServingPooling(
-    # engine_client,
-    # model_config,
-    # base_model_paths,
-    # request_logger=request_logger,
-    # chat_template=resolved_chat_template,
-    # chat_template_content_format=args.chat_template_content_format,
-    # ) if model_config.runner_type == "pooling" else None
-    # state.openai_serving_embedding = OpenAIServingEmbedding(
-    # engine_client,
-    # model_config,
-    # base_model_paths,
-    # request_logger=request_logger,
-    # chat_template=resolved_chat_template,
-    # chat_template_content_format=args.chat_template_content_format,
-    # ) if model_config.task == "embed" else None
-    # state.openai_serving_scores = OpenAIServingScores(
-    # engine_client,
-    # model_config,
-    # base_model_paths,
-    # request_logger=request_logger
-    # ) if model_config.task == "score" else None
-    # state.openai_serving_tokenization = OpenAIServingTokenization(
-    # engine_client,
-    # model_config,
-    # base_model_paths,
-    # lora_modules=args.lora_modules,
-    # request_logger=request_logger,
-    # chat_template=resolved_chat_template,
-    # chat_template_content_format=args.chat_template_content_format,
-    # )
-
-
-def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
+
+    state.enable_server_load_tracking = args.enable_server_load_tracking
+    state.server_load_metrics = 0
+
+
+def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     family = socket.AF_INET
     if is_valid_ipv6_address(addr[0]):
         family = socket.AF_INET6

     sock = socket.socket(family=family, type=socket.SOCK_STREAM)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
     sock.bind(addr)

     return sock
@@ -840,11 +1083,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)

-
+    valid_tool_parses = ToolParserManager.tool_parsers.keys()
     if args.enable_auto_tool_choice \
-            and args.tool_call_parser not in
+            and args.tool_call_parser not in valid_tool_parses:
         raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
-                       f"(chose from {{ {','.join(
+                       f"(chose from {{ {','.join(valid_tool_parses)} }})")
+
+    valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
+    if args.enable_reasoning \
+            and args.reasoning_parser not in valid_reasoning_parses:
+        raise KeyError(
+            f"invalid reasoning parser: {args.reasoning_parser} "
+            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")

     # workaround to make sure that we bind the port before the engine is set up.
     # This avoids race conditions with ray.
@@ -866,13 +1116,28 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     app = build_app(args)

     model_config = await engine_client.get_model_config()
-    init_app_state(engine_client, model_config, app.state, args)
+    await init_app_state(engine_client, model_config, app.state, args)
+
+    def _listen_addr(a: str) -> str:
+        if is_valid_ipv6_address(a):
+            return '[' + a + ']'
+        return a or "0.0.0.0"
+
+    is_ssl = args.ssl_keyfile and args.ssl_certfile
+    logger.info("Starting vLLM API server on http%s://%s:%d",
+                "s" if is_ssl else "", _listen_addr(sock_addr[0]),
+                sock_addr[1])

     shutdown_task = await serve_http(
         app,
+        sock=sock,
+        enable_ssl_refresh=args.enable_ssl_refresh,
         host=args.host,
         port=args.port,
         log_level=args.uvicorn_log_level,
+        # NOTE: When the 'disable_uvicorn_access_log' value is True,
+        # no access log will be output.
+        access_log=not args.disable_uvicorn_access_log,
         timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
         ssl_keyfile=args.ssl_keyfile,
         ssl_certfile=args.ssl_certfile,
@@ -882,16 +1147,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     )

     # NB: Await server shutdown only after the backend context is exited
-
-
-
+    try:
+        await shutdown_task
+    finally:
+        sock.close()


 if __name__ == "__main__":
     # NOTE(simon):
-    # This section should be in sync with vllm/
+    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
+    # entrypoints.
     logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
                    "instead of `vllm.entrypoints.openai.api_server` to start the API server")
+    cli_env_setup()
     parser = FlexibleArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)