ipex-llm 2.3.0b20250428__py3-none-win_amd64.whl → 2.3.0b20250501__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. ipex_llm/libs/bloom-api.dll +0 -0
  2. ipex_llm/libs/bloom.dll +0 -0
  3. ipex_llm/libs/gptneox-api.dll +0 -0
  4. ipex_llm/libs/gptneox.dll +0 -0
  5. ipex_llm/libs/libbloom_avx.dll +0 -0
  6. ipex_llm/libs/libbloom_vnni.dll +0 -0
  7. ipex_llm/libs/libgptneox_avx.dll +0 -0
  8. ipex_llm/libs/libgptneox_vnni.dll +0 -0
  9. ipex_llm/libs/libllama_avx.dll +0 -0
  10. ipex_llm/libs/libllama_vnni.dll +0 -0
  11. ipex_llm/libs/libstarcoder_avx.dll +0 -0
  12. ipex_llm/libs/libstarcoder_vnni.dll +0 -0
  13. ipex_llm/libs/llama-api.dll +0 -0
  14. ipex_llm/libs/llama.dll +0 -0
  15. ipex_llm/libs/main-bloom.exe +0 -0
  16. ipex_llm/libs/main-gptneox.exe +0 -0
  17. ipex_llm/libs/main-llama.exe +0 -0
  18. ipex_llm/libs/main-starcoder.exe +0 -0
  19. ipex_llm/libs/pipeline.dll +0 -0
  20. ipex_llm/libs/quantize-bloom.exe +0 -0
  21. ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
  22. ipex_llm/libs/quantize-gptneox.exe +0 -0
  23. ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
  24. ipex_llm/libs/quantize-llama.exe +0 -0
  25. ipex_llm/libs/quantize-llama_vnni.exe +0 -0
  26. ipex_llm/libs/quantize-starcoder.exe +0 -0
  27. ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
  28. ipex_llm/libs/starcoder-api.dll +0 -0
  29. ipex_llm/libs/starcoder.dll +0 -0
  30. ipex_llm/transformers/convert.py +3 -2
  31. ipex_llm/vllm/xpu/engine/__init__.py +3 -1
  32. ipex_llm/vllm/xpu/engine/engine.py +163 -19
  33. ipex_llm/vllm/xpu/entrypoints/openai/api_server.py +448 -180
  34. ipex_llm/vllm/xpu/model_convert.py +5 -2
  35. {ipex_llm-2.3.0b20250428.dist-info → ipex_llm-2.3.0b20250501.dist-info}/METADATA +11 -11
  36. {ipex_llm-2.3.0b20250428.dist-info → ipex_llm-2.3.0b20250501.dist-info}/RECORD +42 -42
  37. {ipex_llm-2.3.0b20250428.data → ipex_llm-2.3.0b20250501.data}/scripts/ipex-llm-init.bat +0 -0
  38. {ipex_llm-2.3.0b20250428.data → ipex_llm-2.3.0b20250501.data}/scripts/llm-chat.ps1 +0 -0
  39. {ipex_llm-2.3.0b20250428.data → ipex_llm-2.3.0b20250501.data}/scripts/llm-cli.ps1 +0 -0
  40. {ipex_llm-2.3.0b20250428.dist-info → ipex_llm-2.3.0b20250501.dist-info}/WHEEL +0 -0
  41. {ipex_llm-2.3.0b20250428.dist-info → ipex_llm-2.3.0b20250501.dist-info}/entry_points.txt +0 -0
  42. {ipex_llm-2.3.0b20250428.dist-info → ipex_llm-2.3.0b20250501.dist-info}/top_level.txt +0 -0
ipex_llm/vllm/xpu/entrypoints/openai/api_server.py
@@ -1,5 +1,8 @@
+ # SPDX-License-Identifier: Apache-2.0
+
  import asyncio
  import atexit
+ import gc
  import importlib
  import inspect
  import multiprocessing
@@ -10,16 +13,18 @@ import socket
  import tempfile
  import uuid
  from argparse import Namespace
+ from collections.abc import AsyncIterator
  from contextlib import asynccontextmanager
  from functools import partial
  from http import HTTPStatus
- from typing import AsyncIterator, Optional, Set, Tuple
+ from typing import Annotated, Optional, Union

  import uvloop
- from fastapi import APIRouter, FastAPI, Request
+ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
  from fastapi.exceptions import RequestValidationError
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse, Response, StreamingResponse
+ from starlette.concurrency import iterate_in_threadpool
  from starlette.datastructures import State
  from starlette.routing import Mount
  from typing_extensions import assert_never
@@ -27,17 +32,17 @@ from typing_extensions import assert_never
  import vllm.envs as envs
  from vllm.config import ModelConfig
  from vllm.engine.arg_utils import AsyncEngineArgs
- from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
+ from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine  # type: ignore
  from vllm.engine.multiprocessing.client import MQLLMEngineClient
  from ipex_llm.vllm.xpu.engine import run_mp_engine
  from vllm.engine.protocol import EngineClient
- from vllm.entrypoints.chat_utils import load_chat_template
+ from vllm.entrypoints.chat_utils import (load_chat_template,
+                                          resolve_hf_chat_template,
+                                          resolve_mistral_chat_template)
  from vllm.entrypoints.launcher import serve_http
  from vllm.entrypoints.logger import RequestLogger
  from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                                validate_parsed_serve_args)
-
- # from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
  # yapf conflicts with isort for this block
  # yapf: disable
  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -46,33 +51,46 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                                CompletionResponse,
                                                DetokenizeRequest,
                                                DetokenizeResponse,
+                                               EmbeddingChatRequest,
+                                               EmbeddingCompletionRequest,
                                                EmbeddingRequest,
                                                EmbeddingResponse,
                                                EmbeddingResponseData,
                                                ErrorResponse,
-                                               LoadLoraAdapterRequest,
+                                               LoadLoRAAdapterRequest,
+                                               PoolingChatRequest,
+                                               PoolingCompletionRequest,
                                                PoolingRequest, PoolingResponse,
+                                               RerankRequest, RerankResponse,
                                                ScoreRequest, ScoreResponse,
                                                TokenizeRequest,
                                                TokenizeResponse,
-                                               UnloadLoraAdapterRequest)
+                                               TranscriptionRequest,
+                                               TranscriptionResponse,
+                                               UnloadLoRAAdapterRequest)
  # yapf: enable
  from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
  from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
  from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
  from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                      OpenAIServingModels)
-
- from vllm.entrypoints.openai.serving_engine import OpenAIServing
  from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
- from vllm.entrypoints.openai.serving_score import OpenAIServingScores
+ from vllm.entrypoints.openai.serving_score import ServingScores
  from vllm.entrypoints.openai.serving_tokenization import (
      OpenAIServingTokenization)
+ from vllm.entrypoints.openai.serving_transcription import (
+     OpenAIServingTranscription)
  from vllm.entrypoints.openai.tool_parsers import ToolParserManager
- from vllm.entrypoints.utils import with_cancellation
+ from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
+                                     with_cancellation)
  from vllm.logger import init_logger
+ from vllm.reasoning import ReasoningParserManager
+ from vllm.transformers_utils.config import (
+     maybe_register_config_serialize_by_value)
+ from vllm.transformers_utils.tokenizer import MistralTokenizer
  from vllm.usage.usage_lib import UsageContext
- from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+ from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                          is_valid_ipv6_address, set_ulimit)
  from vllm.version import __version__ as VLLM_VERSION

@@ -83,7 +101,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
  # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
  logger = init_logger('vllm.entrypoints.openai.api_server')

- _running_tasks: Set[asyncio.Task] = set()
+ _running_tasks: set[asyncio.Task] = set()


  @asynccontextmanager
@@ -102,6 +120,11 @@ async def lifespan(app: FastAPI):
              task.add_done_callback(_running_tasks.remove)
          else:
              task = None
+
+         # Mark the startup heap as static so that it's ignored by GC.
+         # Reduces pause times of oldest generation collections.
+         gc.collect()
+         gc.freeze()
          try:
              yield
          finally:
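Note on the gc.collect()/gc.freeze() addition above: freezing moves every object that survives the startup collection into a permanent generation that later collections skip, which shortens GC pauses once the server is serving traffic. A minimal standalone sketch of the same pattern (the names below are illustrative, not taken from the diff):

import gc

startup_state = [object() for _ in range(1000)]  # stand-in for objects built during startup
gc.collect()                   # drop garbage produced while starting up
gc.freeze()                    # survivors move to the permanent generation
print(gc.get_freeze_count())   # how many objects are now exempt from collection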
@@ -139,24 +162,49 @@ async def build_async_engine_client_from_engine_args(
      Returns the Client or None if the creation failed.
      """

-     # Fall back
-     # TODO: fill out feature matrix.
-     if (MQLLMEngineClient.is_unsupported_config(engine_args)
-             or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
+     # Create the EngineConfig (determines if we can use V1).
+     usage_context = UsageContext.OPENAI_API_SERVER
+     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+
+     # V1 AsyncLLM.
+     if envs.VLLM_USE_V1:
+         if disable_frontend_multiprocessing:
+             logger.warning(
+                 "V1 is enabled, but got --disable-frontend-multiprocessing. "
+                 "To disable frontend multiprocessing, set VLLM_USE_V1=0.")
+
+         from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine as AsyncLLM
+         async_llm: Optional[AsyncLLM] = None
+         try:
+             async_llm = AsyncLLM.from_vllm_config(
+                 vllm_config=vllm_config,
+                 usage_context=usage_context,
+                 disable_log_requests=engine_args.disable_log_requests,
+                 disable_log_stats=engine_args.disable_log_stats,
+                 load_in_low_bit=load_in_low_bit)
+             yield async_llm
+         finally:
+             if async_llm:
+                 async_llm.shutdown()
+
+     # V0 AsyncLLM.
+     elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
+           or disable_frontend_multiprocessing):
+
          engine_client: Optional[EngineClient] = None
          try:
-             # When starting this, we are actually starting with the V1Engine
-             # Here we are doing a classification, we will need to do this in IPEX-LLM
-             engine_client = AsyncLLMEngine.from_engine_args(
-                 engine_args=engine_args,
-                 usage_context=UsageContext.OPENAI_API_SERVER,
+             engine_client = AsyncLLMEngine.from_vllm_config(
+                 vllm_config=vllm_config,
+                 usage_context=usage_context,
+                 disable_log_requests=engine_args.disable_log_requests,
+                 disable_log_stats=engine_args.disable_log_stats,
                  load_in_low_bit=load_in_low_bit)
              yield engine_client
          finally:
              if engine_client and hasattr(engine_client, "shutdown"):
                  engine_client.shutdown()

-     # Otherwise, use the multiprocessing AsyncLLMEngine.
+     # V0MQLLMEngine.
      else:
          if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
              # Make TemporaryDirectory for prometheus multiprocessing
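The rewritten branch above selects one of three clients: the IPEX-LLM V1 AsyncLLM when VLLM_USE_V1 is set, the in-process V0 AsyncLLMEngine when the multiprocessing frontend is unsupported or disabled, and otherwise the multiprocessing MQLLMEngineClient. An illustrative helper (not part of the server code) mirroring that decision order:

def pick_engine_path(use_v1: bool, unsupported_mq_config: bool,
                     disable_frontend_multiprocessing: bool) -> str:
    # Mirrors the if/elif/else ordering in the hunk above.
    if use_v1:
        return "IPEXLLMAsyncV1Engine (V1 AsyncLLM)"
    if unsupported_mq_config or disable_frontend_multiprocessing:
        return "IPEXLLMAsyncLLMEngine (V0, in-process)"
    return "MQLLMEngineClient (V0, multiprocessing frontend)"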
@@ -183,14 +231,18 @@ async def build_async_engine_client_from_engine_args(
          # so we need to spawn a new process
          context = multiprocessing.get_context("spawn")

+         # Ensure we can serialize transformer config before spawning
+         maybe_register_config_serialize_by_value()
+
          # The Process can raise an exception during startup, which may
          # not actually result in an exitcode being reported. As a result
          # we use a shared variable to communicate the information.
          engine_alive = multiprocessing.Value('b', True, lock=False)
-         engine_process = context.Process(target=run_mp_engine,
-                                          args=(engine_args,
-                                                UsageContext.OPENAI_API_SERVER,
-                                                ipc_path, load_in_low_bit, engine_alive))
+         engine_process = context.Process(
+             target=run_mp_engine,
+             args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
+                   engine_args.disable_log_stats,
+                   engine_args.disable_log_requests, load_in_low_bit, engine_alive))
          engine_process.start()
          engine_pid = engine_process.pid
          assert engine_pid is not None, "Engine process failed to start."
@@ -205,8 +257,7 @@ async def build_async_engine_client_from_engine_args(
          atexit.register(_cleanup_ipc_path)

          # Build RPCClient, which conforms to EngineClient Protocol.
-         engine_config = engine_args.create_engine_config()
-         build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
+         build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
                                 engine_pid)
          mq_engine_client = await asyncio.get_running_loop().run_in_executor(
              None, build_client)
@@ -244,6 +295,43 @@ async def build_async_engine_client_from_engine_args(
              multiprocess.mark_process_dead(engine_process.pid)


+ async def validate_json_request(raw_request: Request):
+     content_type = raw_request.headers.get("content-type", "").lower()
+     media_type = content_type.split(";", maxsplit=1)[0]
+     if media_type != "application/json":
+         raise HTTPException(
+             status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+             detail="Unsupported Media Type: Only 'application/json' is allowed"
+         )
+
+
+ save_dict = {}
+ import os
+ flag = os.getenv("VLLM_LOG_OUTPUT", None)
+ async def stream_generator(generator, request, request_id):
+     async for chunk in generator:
+         if request_id not in save_dict:
+             save_dict[request_id] = ""
+         import json
+         try:
+             data = chunk.strip()
+             if data.startswith('data: '):
+                 data = data[len('data: '):]
+             else:
+                 yield chunk
+             json_data = json.loads(data)
+             if 'choices' in json_data and len(json_data['choices']) > 0:
+                 choice = json_data['choices'][0]
+                 if 'delta' in choice:
+                     save_dict[request_id] += choice["delta"]["content"]
+                 elif 'text' in choice:
+                     save_dict[request_id] += choice["text"]
+         except json.JSONDecodeError:
+             print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
+             pass # Done
+         yield chunk
+
+
  router = APIRouter()

@@ -254,6 +342,7 @@ def mount_metrics(app: FastAPI):
      # See https://prometheus.github.io/client_python/multiprocess/
      from prometheus_client import (CollectorRegistry, make_asgi_app,
                                     multiprocess)
+     from prometheus_fastapi_instrumentator import Instrumentator

      prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
      if prometheus_multiproc_dir_path is not None:
@@ -261,6 +350,16 @@ def mount_metrics(app: FastAPI):
                       prometheus_multiproc_dir_path)
          registry = CollectorRegistry()
          multiprocess.MultiProcessCollector(registry)
+         Instrumentator(
+             excluded_handlers=[
+                 "/metrics",
+                 "/health",
+                 "/load",
+                 "/ping",
+                 "/version",
+             ],
+             registry=registry,
+         ).add().instrument(app).expose(app)

      # Add prometheus asgi middleware to route /metrics requests
      metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
@@ -298,7 +397,11 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
      return request.app.state.openai_serving_embedding


- def score(request: Request) -> Optional[OpenAIServingScores]:
+ def score(request: Request) -> Optional[ServingScores]:
+     return request.app.state.openai_serving_scores
+
+
+ def rerank(request: Request) -> Optional[ServingScores]:
      return request.app.state.openai_serving_scores

@@ -306,6 +409,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
      return request.app.state.openai_serving_tokenization


+ def transcription(request: Request) -> OpenAIServingTranscription:
+     return request.app.state.openai_serving_transcription
+
+
  def engine_client(request: Request) -> EngineClient:
      return request.app.state.engine_client

@@ -317,7 +424,31 @@ async def health(raw_request: Request) -> Response:
      return Response(status_code=200)


- @router.post("/tokenize")
+ @router.get("/load")
+ async def get_server_load_metrics(request: Request):
+     # This endpoint returns the current server load metrics.
+     # It tracks requests utilizing the GPU from the following routes:
+     # - /v1/chat/completions
+     # - /v1/completions
+     # - /v1/audio/transcriptions
+     # - /v1/embeddings
+     # - /pooling
+     # - /score
+     # - /v1/score
+     # - /rerank
+     # - /v1/rerank
+     # - /v2/rerank
+     return JSONResponse(
+         content={'server_load': request.app.state.server_load_metrics})
+
+
+ @router.api_route("/ping", methods=["GET", "POST"])
+ async def ping(raw_request: Request) -> Response:
+     """Ping check. Endpoint required for SageMaker"""
+     return await health(raw_request)
+
+
+ @router.post("/tokenize", dependencies=[Depends(validate_json_request)])
  @with_cancellation
  async def tokenize(request: TokenizeRequest, raw_request: Request):
      handler = tokenization(raw_request)
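The new /load route simply reports request.app.state.server_load_metrics, the counter that @load_aware_call maintains for the routes listed in the handler's comment. A hedged usage sketch (assumes a server on localhost:8000 started with load tracking enabled, e.g. via --enable-server-load-tracking; adjust host and port to your deployment):

import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/load") as resp:
    payload = json.load(resp)
print(payload["server_load"])  # current number of tracked in-flight requests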
@@ -332,7 +463,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
      assert_never(generator)


- @router.post("/detokenize")
+ @router.post("/detokenize", dependencies=[Depends(validate_json_request)])
  @with_cancellation
  async def detokenize(request: DetokenizeRequest, raw_request: Request):
      handler = tokenization(raw_request)
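With the validate_json_request dependency now attached to these routes, any body not declared as application/json is rejected with HTTP 415 before the handler runs. A hypothetical client-side check (assumes a local server on port 8000; the payload fields are placeholders):

import json
import urllib.error
import urllib.request

req = urllib.request.Request(
    "http://localhost:8000/tokenize",
    data=json.dumps({"model": "my-model", "prompt": "hello"}).encode(),
    headers={"Content-Type": "text/plain"},  # deliberately not JSON
    method="POST",
)
try:
    urllib.request.urlopen(req)
except urllib.error.HTTPError as err:
    print(err.code)  # expected: 415 (Unsupported Media Type)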
@@ -361,35 +492,10 @@ async def show_version():
      return JSONResponse(content=ver)


- save_dict = {}
- import os
- flag = os.getenv("VLLM_LOG_OUTPUT", None)
- async def stream_generator(generator, request, request_id):
-     async for chunk in generator:
-         if request_id not in save_dict:
-             save_dict[request_id] = ""
-         import json
-         try:
-             data = chunk.strip()
-             if data.startswith('data: '):
-                 data = data[len('data: '):]
-             else:
-                 yield chunk
-             json_data = json.loads(data)
-             if 'choices' in json_data and len(json_data['choices']) > 0:
-                 choice = json_data['choices'][0]
-                 if 'delta' in choice:
-                     save_dict[request_id] += choice["delta"]["content"]
-                 elif 'text' in choice:
-                     save_dict[request_id] += choice["text"]
-         except json.JSONDecodeError:
-             print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
-             pass # Done
-         yield chunk
-
-
- @router.post("/v1/chat/completions")
+ @router.post("/v1/chat/completions",
+              dependencies=[Depends(validate_json_request)])
  @with_cancellation
+ @load_aware_call
  async def create_chat_completion(request: ChatCompletionRequest,
                                   raw_request: Request):
      handler = chat(raw_request)
@@ -401,7 +507,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
      request_id = "chatcmpl-" \
          f"{handler._base_request_id(raw_request, request.request_id)}"
      print(f"First received request_id: {request_id}, request: {request}")
-
      generator = await handler.create_chat_completion(request, raw_request)

      if isinstance(generator, ErrorResponse):
@@ -418,8 +523,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
      return StreamingResponse(content=generator, media_type="text/event-stream")


- @router.post("/v1/completions")
+ @router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
  @with_cancellation
+ @load_aware_call
  async def create_completion(request: CompletionRequest, raw_request: Request):
      handler = completion(raw_request)
      if handler is None:
@@ -438,14 +544,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
      if flag is not None:
          print(f"Received request-id:{request_id}, request:{request}, Output:{generator.model_dump()}")
          return JSONResponse(content=generator.model_dump())
-
+
      if flag is not None:
          return StreamingResponse(content=stream_generator(generator, request, request_id), media_type="text/event-stream")
      return StreamingResponse(content=generator, media_type="text/event-stream")


- @router.post("/v1/embeddings")
+ @router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
  @with_cancellation
+ @load_aware_call
  async def create_embedding(request: EmbeddingRequest, raw_request: Request):
      handler = embedding(raw_request)
      if handler is None:
@@ -460,6 +567,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
              "use the Pooling API (`/pooling`) instead.")

          res = await fallback_handler.create_pooling(request, raw_request)
+
+         generator: Union[ErrorResponse, EmbeddingResponse]
          if isinstance(res, PoolingResponse):
              generator = EmbeddingResponse(
                  id=res.id,
@@ -488,8 +597,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
      assert_never(generator)


- @router.post("/pooling")
+ @router.post("/pooling", dependencies=[Depends(validate_json_request)])
  @with_cancellation
+ @load_aware_call
  async def create_pooling(request: PoolingRequest, raw_request: Request):
      handler = pooling(raw_request)
      if handler is None:
@@ -506,8 +616,9 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
      assert_never(generator)


- @router.post("/score")
+ @router.post("/score", dependencies=[Depends(validate_json_request)])
  @with_cancellation
+ @load_aware_call
  async def create_score(request: ScoreRequest, raw_request: Request):
      handler = score(raw_request)
      if handler is None:
@@ -524,8 +635,9 @@ async def create_score(request: ScoreRequest, raw_request: Request):
      assert_never(generator)


- @router.post("/v1/score")
+ @router.post("/v1/score", dependencies=[Depends(validate_json_request)])
  @with_cancellation
+ @load_aware_call
  async def create_score_v1(request: ScoreRequest, raw_request: Request):
      logger.warning(
          "To indicate that Score API is not part of standard OpenAI API, we "
@@ -534,6 +646,160 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
      return await create_score(request, raw_request)


+ @router.post("/v1/audio/transcriptions")
+ @with_cancellation
+ @load_aware_call
+ async def create_transcriptions(request: Annotated[TranscriptionRequest,
+                                                    Form()],
+                                 raw_request: Request):
+     handler = transcription(raw_request)
+     if handler is None:
+         return base(raw_request).create_error_response(
+             message="The model does not support Transcriptions API")
+
+     audio_data = await request.file.read()
+     generator = await handler.create_transcription(audio_data, request,
+                                                    raw_request)
+
+     if isinstance(generator, ErrorResponse):
+         return JSONResponse(content=generator.model_dump(),
+                             status_code=generator.code)
+
+     elif isinstance(generator, TranscriptionResponse):
+         return JSONResponse(content=generator.model_dump())
+
+     return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+ @router.post("/rerank", dependencies=[Depends(validate_json_request)])
+ @with_cancellation
+ @load_aware_call
+ async def do_rerank(request: RerankRequest, raw_request: Request):
+     handler = rerank(raw_request)
+     if handler is None:
+         return base(raw_request).create_error_response(
+             message="The model does not support Rerank (Score) API")
+     generator = await handler.do_rerank(request, raw_request)
+     if isinstance(generator, ErrorResponse):
+         return JSONResponse(content=generator.model_dump(),
+                             status_code=generator.code)
+     elif isinstance(generator, RerankResponse):
+         return JSONResponse(content=generator.model_dump())
+
+     assert_never(generator)
+
+
+ @router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
+ @with_cancellation
+ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
+     logger.warning_once(
+         "To indicate that the rerank API is not part of the standard OpenAI"
+         " API, we have located it at `/rerank`. Please update your client "
+         "accordingly. (Note: Conforms to JinaAI rerank API)")
+
+     return await do_rerank(request, raw_request)
+
+
+ @router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
+ @with_cancellation
+ async def do_rerank_v2(request: RerankRequest, raw_request: Request):
+     return await do_rerank(request, raw_request)
+
+
+ TASK_HANDLERS: dict[str, dict[str, tuple]] = {
+     "generate": {
+         "messages": (ChatCompletionRequest, create_chat_completion),
+         "default": (CompletionRequest, create_completion),
+     },
+     "embed": {
+         "messages": (EmbeddingChatRequest, create_embedding),
+         "default": (EmbeddingCompletionRequest, create_embedding),
+     },
+     "score": {
+         "default": (RerankRequest, do_rerank)
+     },
+     "rerank": {
+         "default": (RerankRequest, do_rerank)
+     },
+     "reward": {
+         "messages": (PoolingChatRequest, create_pooling),
+         "default": (PoolingCompletionRequest, create_pooling),
+     },
+     "classify": {
+         "messages": (PoolingChatRequest, create_pooling),
+         "default": (PoolingCompletionRequest, create_pooling),
+     },
+ }
+
+ if envs.VLLM_SERVER_DEV_MODE:
+
+     @router.post("/reset_prefix_cache")
+     async def reset_prefix_cache(raw_request: Request):
+         """
+         Reset the prefix cache. Note that we currently do not check if the
+         prefix cache is successfully reset in the API server.
+         """
+         device = None
+         device_str = raw_request.query_params.get("device")
+         if device_str is not None:
+             device = Device[device_str.upper()]
+         logger.info("Resetting prefix cache with specific %s...", str(device))
+         await engine_client(raw_request).reset_prefix_cache(device)
+         return Response(status_code=200)
+
+     @router.post("/sleep")
+     async def sleep(raw_request: Request):
+         # get POST params
+         level = raw_request.query_params.get("level", "1")
+         await engine_client(raw_request).sleep(int(level))
+         # FIXME: in v0 with frontend multiprocessing, the sleep command
+         # is sent but does not finish yet when we return a response.
+         return Response(status_code=200)
+
+     @router.post("/wake_up")
+     async def wake_up(raw_request: Request):
+         tags = raw_request.query_params.getlist("tags")
+         if tags == []:
+             # set to None to wake up all tags if no tags are provided
+             tags = None
+         logger.info("wake up the engine with tags: %s", tags)
+         await engine_client(raw_request).wake_up(tags)
+         # FIXME: in v0 with frontend multiprocessing, the wake-up command
+         # is sent but does not finish yet when we return a response.
+         return Response(status_code=200)
+
+     @router.get("/is_sleeping")
+     async def is_sleeping(raw_request: Request):
+         logger.info("check whether the engine is sleeping")
+         is_sleeping = await engine_client(raw_request).is_sleeping()
+         return JSONResponse(content={"is_sleeping": is_sleeping})
+
+
+ @router.post("/invocations", dependencies=[Depends(validate_json_request)])
+ async def invocations(raw_request: Request):
+     """
+     For SageMaker, routes requests to other handlers based on model `task`.
+     """
+     body = await raw_request.json()
+     task = raw_request.app.state.task
+
+     if task not in TASK_HANDLERS:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Unsupported task: '{task}' for '/invocations'. "
+             f"Expected one of {set(TASK_HANDLERS.keys())}")
+
+     handler_config = TASK_HANDLERS[task]
+     if "messages" in body:
+         request_model, handler = handler_config["messages"]
+     else:
+         request_model, handler = handler_config["default"]
+
+     # this is required since we lose the FastAPI automatic casting
+     request = request_model.model_validate(body)
+     return await handler(request, raw_request)
+
+
  if envs.VLLM_TORCH_PROFILER_DIR:
      logger.warning(
          "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -556,32 +822,30 @@ if envs.VLLM_TORCH_PROFILER_DIR:

  if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
      logger.warning(
-         "Lora dynamic loading & unloading is enabled in the API server. "
+         "LoRA dynamic loading & unloading is enabled in the API server. "
          "This should ONLY be used for local development!")

-     @router.post("/v1/load_lora_adapter")
-     async def load_lora_adapter(request: LoadLoraAdapterRequest,
+     @router.post("/v1/load_lora_adapter",
+                  dependencies=[Depends(validate_json_request)])
+     async def load_lora_adapter(request: LoadLoRAAdapterRequest,
                                  raw_request: Request):
-         for route in [chat, completion, embedding]:
-             handler = route(raw_request)
-             if handler is not None:
-                 response = await handler.load_lora_adapter(request)
-                 if isinstance(response, ErrorResponse):
-                     return JSONResponse(content=response.model_dump(),
-                                         status_code=response.code)
+         handler = models(raw_request)
+         response = await handler.load_lora_adapter(request)
+         if isinstance(response, ErrorResponse):
+             return JSONResponse(content=response.model_dump(),
+                                 status_code=response.code)

          return Response(status_code=200, content=response)

-     @router.post("/v1/unload_lora_adapter")
-     async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+     @router.post("/v1/unload_lora_adapter",
+                  dependencies=[Depends(validate_json_request)])
+     async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
                                    raw_request: Request):
-         for route in [chat, completion, embedding]:
-             handler = route(raw_request)
-             if handler is not None:
-                 response = await handler.unload_lora_adapter(request)
-                 if isinstance(response, ErrorResponse):
-                     return JSONResponse(content=response.model_dump(),
-                                         status_code=response.code)
+         handler = models(raw_request)
+         response = await handler.unload_lora_adapter(request)
+         if isinstance(response, ErrorResponse):
+             return JSONResponse(content=response.model_dump(),
+                                 status_code=response.code)

          return Response(status_code=200, content=response)

@@ -615,7 +879,8 @@ def build_app(args: Namespace) -> FastAPI:
          return JSONResponse(err.model_dump(),
                              status_code=HTTPStatus.BAD_REQUEST)

-     if token := envs.VLLM_API_KEY or args.api_key:
+     # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
+     if token := args.api_key or envs.VLLM_API_KEY:

          @app.middleware("http")
          async def authentication(request: Request, call_next):
@@ -644,11 +909,26 @@ def build_app(args: Namespace) -> FastAPI:
              response.headers["X-Request-Id"] = request_id
              return response

+     if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
+         logger.warning("CAUTION: Enabling log response in the API Server. "
+                        "This can include sensitive information and should be "
+                        "avoided in production.")
+
+         @app.middleware("http")
+         async def log_response(request: Request, call_next):
+             response = await call_next(request)
+             response_body = [
+                 section async for section in response.body_iterator
+             ]
+             response.body_iterator = iterate_in_threadpool(iter(response_body))
+             logger.info("response_body={%s}", response_body[0].decode())
+             return response
+
      for middleware in args.middleware:
          module_path, object_name = middleware.rsplit(".", 1)
          imported = getattr(importlib.import_module(module_path), object_name)
          if inspect.isclass(imported):
-             app.add_middleware(imported)
+             app.add_middleware(imported)  # type: ignore[arg-type]
          elif inspect.iscoroutinefunction(imported):
              app.middleware("http")(imported)
          else:
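The debug log_response middleware above has to drain response.body_iterator in order to log it, then hand the client a fresh iterator; iterate_in_threadpool wraps the plain list iterator back into an async one. A minimal standalone sketch of that pattern (illustrative names, not the server's own):

from starlette.concurrency import iterate_in_threadpool

async def drain_and_restore(response):
    # Consume the streamed body so it can be inspected or logged ...
    chunks = [section async for section in response.body_iterator]
    # ... then restore an async iterator so the client still receives the payload.
    response.body_iterator = iterate_in_threadpool(iter(chunks))
    return b"".join(chunks)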
@@ -658,7 +938,7 @@ def build_app(args: Namespace) -> FastAPI:
      return app


- def init_app_state(
+ async def init_app_state(
      engine_client: EngineClient,
      model_config: ModelConfig,
      state: State,
@@ -683,15 +963,36 @@ def init_app_state(
      state.log_stats = not args.disable_log_stats

      resolved_chat_template = load_chat_template(args.chat_template)
-     logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+     if resolved_chat_template is not None:
+         # Get the tokenizer to check official template
+         tokenizer = await engine_client.get_tokenizer()
+
+         if isinstance(tokenizer, MistralTokenizer):
+             # The warning is logged in resolve_mistral_chat_template.
+             resolved_chat_template = resolve_mistral_chat_template(
+                 chat_template=resolved_chat_template)
+         else:
+             hf_chat_template = resolve_hf_chat_template(
+                 tokenizer,
+                 chat_template=None,
+                 tools=None,
+                 trust_remote_code=model_config.trust_remote_code)
+
+             if hf_chat_template != resolved_chat_template:
+                 logger.warning(
+                     "Using supplied chat template: %s\n"
+                     "It is different from official chat template '%s'. "
+                     "This discrepancy may lead to performance degradation.",
+                     resolved_chat_template, args.model)

      state.openai_serving_models = OpenAIServingModels(
+         engine_client=engine_client,
          model_config=model_config,
          base_model_paths=base_model_paths,
          lora_modules=args.lora_modules,
          prompt_adapters=args.prompt_adapters,
      )
-     # TODO: The chat template is now broken for lora adapters :(
+     await state.openai_serving_models.init_static_loras()
      state.openai_serving_chat = OpenAIServingChat(
          engine_client,
          model_config,
@@ -703,6 +1004,8 @@ def init_app_state(
          return_tokens_as_token_ids=args.return_tokens_as_token_ids,
          enable_auto_tools=args.enable_auto_tool_choice,
          tool_parser=args.tool_call_parser,
+         enable_reasoning=args.enable_reasoning,
+         reasoning_parser=args.reasoning_parser,
          enable_prompt_tokens_details=args.enable_prompt_tokens_details,
      ) if model_config.runner_type == "generate" else None
      state.openai_serving_completion = OpenAIServingCompletion(
@@ -728,7 +1031,13 @@ def init_app_state(
          chat_template=resolved_chat_template,
          chat_template_content_format=args.chat_template_content_format,
      ) if model_config.task == "embed" else None
-     state.openai_serving_scores = OpenAIServingScores(
+     state.openai_serving_scores = ServingScores(
+         engine_client,
+         model_config,
+         state.openai_serving_models,
+         request_logger=request_logger) if model_config.task in (
+             "score", "embed", "pooling") else None
+     state.jinaai_serving_reranking = ServingScores(
          engine_client,
          model_config,
          state.openai_serving_models,
@@ -742,92 +1051,26 @@ def init_app_state(
          chat_template=resolved_chat_template,
          chat_template_content_format=args.chat_template_content_format,
      )
+     state.openai_serving_transcription = OpenAIServingTranscription(
+         engine_client,
+         model_config,
+         state.openai_serving_models,
+         request_logger=request_logger,
+     ) if model_config.runner_type == "transcription" else None
      state.task = model_config.task
-     # if args.served_model_name is not None:
-     #     served_model_names = args.served_model_name
-     # else:
-     #     served_model_names = [args.model]
-
-     # if args.disable_log_requests:
-     #     request_logger = None
-     # else:
-     #     request_logger = RequestLogger(max_log_len=args.max_log_len)
-
-     # base_model_paths = [
-     #     BaseModelPath(name=name, model_path=args.model)
-     #     for name in served_model_names
-     # ]
-
-     # state.engine_client = engine_client
-     # state.log_stats = not args.disable_log_stats
-
-     # resolved_chat_template = load_chat_template(args.chat_template)
-     # logger.info("Using supplied chat template:\n%s", resolved_chat_template)
-
-     # state.openai_serving_chat = OpenAIServingChat(
-     #     engine_client,
-     #     model_config,
-     #     base_model_paths,
-     #     args.response_role,
-     #     lora_modules=args.lora_modules,
-     #     prompt_adapters=args.prompt_adapters,
-     #     request_logger=request_logger,
-     #     chat_template=resolved_chat_template,
-     #     chat_template_content_format=args.chat_template_content_format,
-     #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-     #     enable_auto_tools=args.enable_auto_tool_choice,
-     #     tool_parser=args.tool_call_parser,
-     #     enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-     # ) if model_config.runner_type == "generate" else None
-     # state.openai_serving_completion = OpenAIServingCompletion(
-     #     engine_client,
-     #     model_config,
-     #     base_model_paths,
-     #     lora_modules=args.lora_modules,
-     #     prompt_adapters=args.prompt_adapters,
-     #     request_logger=request_logger,
-     #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-     # ) if model_config.runner_type == "generate" else None
-     # state.openai_serving_pooling = OpenAIServingPooling(
-     #     engine_client,
-     #     model_config,
-     #     base_model_paths,
-     #     request_logger=request_logger,
-     #     chat_template=resolved_chat_template,
-     #     chat_template_content_format=args.chat_template_content_format,
-     # ) if model_config.runner_type == "pooling" else None
-     # state.openai_serving_embedding = OpenAIServingEmbedding(
-     #     engine_client,
-     #     model_config,
-     #     base_model_paths,
-     #     request_logger=request_logger,
-     #     chat_template=resolved_chat_template,
-     #     chat_template_content_format=args.chat_template_content_format,
-     # ) if model_config.task == "embed" else None
-     # state.openai_serving_scores = OpenAIServingScores(
-     #     engine_client,
-     #     model_config,
-     #     base_model_paths,
-     #     request_logger=request_logger
-     # ) if model_config.task == "score" else None
-     # state.openai_serving_tokenization = OpenAIServingTokenization(
-     #     engine_client,
-     #     model_config,
-     #     base_model_paths,
-     #     lora_modules=args.lora_modules,
-     #     request_logger=request_logger,
-     #     chat_template=resolved_chat_template,
-     #     chat_template_content_format=args.chat_template_content_format,
-     # )
-
-
- def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
+
+     state.enable_server_load_tracking = args.enable_server_load_tracking
+     state.server_load_metrics = 0
+
+
+ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
      family = socket.AF_INET
      if is_valid_ipv6_address(addr[0]):
          family = socket.AF_INET6

      sock = socket.socket(family=family, type=socket.SOCK_STREAM)
      sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
      sock.bind(addr)

      return sock
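create_server_socket now also sets SO_REUSEPORT alongside SO_REUSEADDR and pre-binds the port before the engine starts. A standalone sketch of the same idea, with the option guarded because SO_REUSEPORT is not defined on every platform (Windows builds of CPython omit it); the helper name and the ":" IPv6 check are illustrative simplifications:

import socket

def make_listen_socket(host: str, port: int) -> socket.socket:
    family = socket.AF_INET6 if ":" in host else socket.AF_INET  # crude IPv6 check
    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    if hasattr(socket, "SO_REUSEPORT"):  # absent on Windows
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    sock.bind((host, port))
    return sock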
@@ -840,11 +1083,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
      if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
          ToolParserManager.import_tool_parser(args.tool_parser_plugin)

-     valide_tool_parses = ToolParserManager.tool_parsers.keys()
+     valid_tool_parses = ToolParserManager.tool_parsers.keys()
      if args.enable_auto_tool_choice \
-         and args.tool_call_parser not in valide_tool_parses:
+         and args.tool_call_parser not in valid_tool_parses:
          raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
-                        f"(chose from {{ {','.join(valide_tool_parses)} }})")
+                        f"(chose from {{ {','.join(valid_tool_parses)} }})")
+
+     valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
+     if args.enable_reasoning \
+         and args.reasoning_parser not in valid_reasoning_parses:
+         raise KeyError(
+             f"invalid reasoning parser: {args.reasoning_parser} "
+             f"(chose from {{ {','.join(valid_reasoning_parses)} }})")

      # workaround to make sure that we bind the port before the engine is set up.
      # This avoids race conditions with ray.
@@ -866,13 +1116,28 @@ async def run_server(args, **uvicorn_kwargs) -> None:
      app = build_app(args)

      model_config = await engine_client.get_model_config()
-     init_app_state(engine_client, model_config, app.state, args)
+     await init_app_state(engine_client, model_config, app.state, args)
+
+     def _listen_addr(a: str) -> str:
+         if is_valid_ipv6_address(a):
+             return '[' + a + ']'
+         return a or "0.0.0.0"
+
+     is_ssl = args.ssl_keyfile and args.ssl_certfile
+     logger.info("Starting vLLM API server on http%s://%s:%d",
+                 "s" if is_ssl else "", _listen_addr(sock_addr[0]),
+                 sock_addr[1])

      shutdown_task = await serve_http(
          app,
+         sock=sock,
+         enable_ssl_refresh=args.enable_ssl_refresh,
          host=args.host,
          port=args.port,
          log_level=args.uvicorn_log_level,
+         # NOTE: When the 'disable_uvicorn_access_log' value is True,
+         # no access log will be output.
+         access_log=not args.disable_uvicorn_access_log,
          timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
          ssl_keyfile=args.ssl_keyfile,
          ssl_certfile=args.ssl_certfile,
@@ -882,16 +1147,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
      )

      # NB: Await server shutdown only after the backend context is exited
-     await shutdown_task
-
-     sock.close()
+     try:
+         await shutdown_task
+     finally:
+         sock.close()


  if __name__ == "__main__":
      # NOTE(simon):
-     # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+     # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
+     # entrypoints.
      logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
                     "instead of `vllm.entrypoints.openai.api_server` to start the API server")
+     cli_env_setup()
      parser = FlexibleArgumentParser(
          description="vLLM OpenAI-Compatible RESTful API server.")
      parser = make_arg_parser(parser)