ipex-llm 2.2.0b20250121__py3-none-win_amd64.whl → 2.2.0b20250123__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ipex_llm/libs/bloom-api.dll +0 -0
  2. ipex_llm/libs/bloom.dll +0 -0
  3. ipex_llm/libs/gptneox-api.dll +0 -0
  4. ipex_llm/libs/gptneox.dll +0 -0
  5. ipex_llm/libs/libbloom_avx.dll +0 -0
  6. ipex_llm/libs/libbloom_vnni.dll +0 -0
  7. ipex_llm/libs/libgptneox_avx.dll +0 -0
  8. ipex_llm/libs/libgptneox_vnni.dll +0 -0
  9. ipex_llm/libs/libllama_avx.dll +0 -0
  10. ipex_llm/libs/libllama_vnni.dll +0 -0
  11. ipex_llm/libs/libstarcoder_avx.dll +0 -0
  12. ipex_llm/libs/libstarcoder_vnni.dll +0 -0
  13. ipex_llm/libs/llama-api.dll +0 -0
  14. ipex_llm/libs/llama.dll +0 -0
  15. ipex_llm/libs/main-bloom.exe +0 -0
  16. ipex_llm/libs/main-gptneox.exe +0 -0
  17. ipex_llm/libs/main-llama.exe +0 -0
  18. ipex_llm/libs/main-starcoder.exe +0 -0
  19. ipex_llm/libs/pipeline.dll +0 -0
  20. ipex_llm/libs/quantize-bloom.exe +0 -0
  21. ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
  22. ipex_llm/libs/quantize-gptneox.exe +0 -0
  23. ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
  24. ipex_llm/libs/quantize-llama.exe +0 -0
  25. ipex_llm/libs/quantize-llama_vnni.exe +0 -0
  26. ipex_llm/libs/quantize-starcoder.exe +0 -0
  27. ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
  28. ipex_llm/libs/starcoder-api.dll +0 -0
  29. ipex_llm/libs/starcoder.dll +0 -0
  30. ipex_llm/transformers/convert.py +0 -1
  31. ipex_llm/transformers/low_bit_linear.py +1 -1
  32. ipex_llm/transformers/model.py +1 -3
  33. ipex_llm/transformers/npu_models/mp_models_base.py +3 -1
  34. ipex_llm/transformers/patches.py +0 -11
  35. ipex_llm/vllm/cpu/engine/__init__.py +2 -1
  36. ipex_llm/vllm/cpu/engine/engine.py +159 -75
  37. ipex_llm/vllm/cpu/entrypoints/api_server.py +787 -0
  38. ipex_llm/vllm/cpu/entrypoints/openai/api_server.py +680 -95
  39. ipex_llm/vllm/cpu/entrypoints/openai/cli_args.py +277 -0
  40. ipex_llm/vllm/cpu/ipex_llm_v1_wrapper.py +23 -0
  41. ipex_llm/vllm/cpu/ipex_llm_wrapper.py +24 -0
  42. ipex_llm/vllm/cpu/model_convert.py +126 -233
  43. {ipex_llm-2.2.0b20250121.dist-info → ipex_llm-2.2.0b20250123.dist-info}/METADATA +20 -20
  44. {ipex_llm-2.2.0b20250121.dist-info → ipex_llm-2.2.0b20250123.dist-info}/RECORD +50 -46
  45. {ipex_llm-2.2.0b20250121.data → ipex_llm-2.2.0b20250123.data}/scripts/ipex-llm-init.bat +0 -0
  46. {ipex_llm-2.2.0b20250121.data → ipex_llm-2.2.0b20250123.data}/scripts/llm-chat.ps1 +0 -0
  47. {ipex_llm-2.2.0b20250121.data → ipex_llm-2.2.0b20250123.data}/scripts/llm-cli.ps1 +0 -0
  48. {ipex_llm-2.2.0b20250121.dist-info → ipex_llm-2.2.0b20250123.dist-info}/WHEEL +0 -0
  49. {ipex_llm-2.2.0b20250121.dist-info → ipex_llm-2.2.0b20250123.dist-info}/entry_points.txt +0 -0
  50. {ipex_llm-2.2.0b20250121.dist-info → ipex_llm-2.2.0b20250123.dist-info}/top_level.txt +0 -0
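Most of the substantive changes are in the vLLM CPU serving stack (items 35-42 above), which is rebased onto upstream vLLM's multiprocessing frontend and threads a new --load-in-low-bit option through engine construction. As a rough, hypothetical sketch of driving the reworked entrypoint programmatically (it simply mirrors the __main__ block visible in the api_server.py diff below; the model path and port are placeholders, not values defined by the package):

# Hypothetical launch sketch; mirrors the __main__ block in the api_server.py
# diff below. The model path and port are placeholders.
import uvloop
from vllm.utils import FlexibleArgumentParser
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from ipex_llm.vllm.cpu.entrypoints.openai.api_server import run_server

parser = FlexibleArgumentParser(
    description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser.add_argument("--load-in-low-bit", type=str, default="sym_int4",
                    help="Low-bit quantization for IPEX-LLM models")
args = parser.parse_args([
    "--model", "/path/to/your/model",   # placeholder model path
    "--load-in-low-bit", "sym_int4",    # default shown in the diff below
    "--port", "8000",                   # placeholder port
])
validate_parsed_serve_args(args)
uvloop.run(run_server(args))            # run_server is defined in the diffed module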
ipex_llm/vllm/cpu/entrypoints/openai/api_server.py
@@ -1,138 +1,559 @@
  import asyncio
+ import atexit
  import importlib
  import inspect
+ import multiprocessing
  import os
  import re
+ import signal
+ import socket
+ import tempfile
+ import uuid
+ from argparse import Namespace
  from contextlib import asynccontextmanager
+ from functools import partial
  from http import HTTPStatus
- from typing import Any, Set
+ from typing import AsyncIterator, Optional, Set, Tuple

- import fastapi
- import uvicorn
- from fastapi import Request
+ import uvloop
+ from fastapi import APIRouter, FastAPI, Request
  from fastapi.exceptions import RequestValidationError
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse, Response, StreamingResponse
- from prometheus_client import make_asgi_app
+ from starlette.datastructures import State
  from starlette.routing import Mount
+ from typing_extensions import assert_never

- import vllm
  import vllm.envs as envs
+ from vllm.config import ModelConfig
  from vllm.engine.arg_utils import AsyncEngineArgs
- from vllm.engine.async_llm_engine import AsyncLLMEngine
- from vllm.entrypoints.openai.cli_args import make_arg_parser
+ from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
+ from vllm.engine.multiprocessing.client import MQLLMEngineClient
+ from ipex_llm.vllm.cpu.engine import run_mp_engine
+ from vllm.engine.protocol import EngineClient
+ from vllm.entrypoints.chat_utils import load_chat_template
+ from vllm.entrypoints.launcher import serve_http
+ from vllm.entrypoints.logger import RequestLogger
+ from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+ validate_parsed_serve_args)
+ # yapf conflicts with isort for this block
+ # yapf: disable
  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
  ChatCompletionResponse,
- CompletionRequest, ErrorResponse)
+ CompletionRequest,
+ CompletionResponse,
+ DetokenizeRequest,
+ DetokenizeResponse,
+ EmbeddingRequest,
+ EmbeddingResponse,
+ EmbeddingResponseData,
+ ErrorResponse,
+ LoadLoraAdapterRequest,
+ PoolingRequest, PoolingResponse,
+ ScoreRequest, ScoreResponse,
+ TokenizeRequest,
+ TokenizeResponse,
+ UnloadLoraAdapterRequest)
+ # yapf: enable
  from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
  from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+ from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
+ from vllm.entrypoints.openai.serving_score import OpenAIServingScores
+ from vllm.entrypoints.openai.serving_tokenization import (
+ OpenAIServingTokenization)
+ from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+ from vllm.entrypoints.utils import with_cancellation
  from vllm.logger import init_logger
  from vllm.usage.usage_lib import UsageContext
-
- from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine
- from ipex_llm.utils.common import invalidInputError
+ from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+ is_valid_ipv6_address, set_ulimit)
+ from vllm.version import __version__ as VLLM_VERSION

  TIMEOUT_KEEP_ALIVE = 5 # seconds

- openai_serving_chat: OpenAIServingChat
- openai_serving_completion: OpenAIServingCompletion
- logger = init_logger(__name__)
+ prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+ # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+ logger = init_logger('vllm.entrypoints.openai.api_server')

- _running_tasks: Set[asyncio.Task[Any]] = set()
+ _running_tasks: Set[asyncio.Task] = set()


  @asynccontextmanager
- async def lifespan(app: fastapi.FastAPI):
+ async def lifespan(app: FastAPI):
+ try:
+ if app.state.log_stats:
+ engine_client: EngineClient = app.state.engine_client
+
+ async def _force_log():
+ while True:
+ await asyncio.sleep(10.)
+ await engine_client.do_log_stats()
+
+ task = asyncio.create_task(_force_log())
+ _running_tasks.add(task)
+ task.add_done_callback(_running_tasks.remove)
+ else:
+ task = None
+ try:
+ yield
+ finally:
+ if task is not None:
+ task.cancel()
+ finally:
+ # Ensure app state including engine ref is gc'd
+ del app.state

- async def _force_log():
- while True:
- await asyncio.sleep(10)
- await engine.do_log_stats()

- if not engine_args.disable_log_stats:
- task = asyncio.create_task(_force_log())
- _running_tasks.add(task)
- task.add_done_callback(_running_tasks.remove)
+ @asynccontextmanager
+ async def build_async_engine_client(
+ args: Namespace) -> AsyncIterator[EngineClient]:

- yield
+ # Context manager to handle engine_client lifecycle
+ # Ensures everything is shutdown and cleaned up on error/exit
+ engine_args = AsyncEngineArgs.from_cli_args(args)

+ async with build_async_engine_client_from_engine_args(
+ engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
+ yield engine

- app = fastapi.FastAPI(lifespan=lifespan)

+ @asynccontextmanager
+ async def build_async_engine_client_from_engine_args(
+ engine_args: AsyncEngineArgs,
+ disable_frontend_multiprocessing: bool = False,
+ load_in_low_bit: str = "sym_int4",
+ ) -> AsyncIterator[EngineClient]:
+ """
+ Create EngineClient, either:
+ - in-process using the AsyncLLMEngine Directly
+ - multiprocess using AsyncLLMEngine RPC
+
+ Returns the Client or None if the creation failed.
+ """
+
+ # Fall back
+ # TODO: fill out feature matrix.
+ if (MQLLMEngineClient.is_unsupported_config(engine_args)
+ or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
+ engine_config = engine_args.create_engine_config(
+ UsageContext.OPENAI_API_SERVER)
+ uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+ "uses_ray", False)
+
+ build_engine = partial(AsyncLLMEngine.from_engine_args,
+ engine_args=engine_args,
+ engine_config=engine_config,
+ load_in_low_bit=load_in_low_bit,
+ usage_context=UsageContext.OPENAI_API_SERVER)
+ if uses_ray:
+ # Must run in main thread with ray for its signal handlers to work
+ engine_client = build_engine()
+ else:
+ engine_client = await asyncio.get_running_loop().run_in_executor(
+ None, build_engine)

- def parse_args():
- parser = make_arg_parser()
- parser.add_argument(
- "--load-in-low-bit",
- type=str,
- default=None,
- help="Low-bit quantization for IPEX-LLM models")
- return parser.parse_args()
+ yield engine_client
+ if hasattr(engine_client, "shutdown"):
+ engine_client.shutdown()
+ return
+
+ # Otherwise, use the multiprocessing AsyncLLMEngine.
+ else:
+ if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+ # Make TemporaryDirectory for prometheus multiprocessing
+ # Note: global TemporaryDirectory will be automatically
+ # cleaned up upon exit.
+ global prometheus_multiproc_dir
+ prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+ os.environ[
+ "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+ else:
+ logger.warning(
+ "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+ "This directory must be wiped between vLLM runs or "
+ "you will find inaccurate metrics. Unset the variable "
+ "and vLLM will properly handle cleanup.")
+
+ # Select random path for IPC.
+ ipc_path = get_open_zmq_ipc_path()
+ logger.debug("Multiprocessing frontend to use %s for IPC Path.",
+ ipc_path)
+
+ # Start RPCServer in separate process (holds the LLMEngine).
+ # the current process might have CUDA context,
+ # so we need to spawn a new process
+ context = multiprocessing.get_context("spawn")
+
+ # The Process can raise an exception during startup, which may
+ # not actually result in an exitcode being reported. As a result
+ # we use a shared variable to communicate the information.
+ engine_alive = multiprocessing.Value('b', True, lock=False)
+ engine_process = context.Process(target=run_mp_engine,
+ args=(engine_args,
+ UsageContext.OPENAI_API_SERVER,
+ ipc_path, load_in_low_bit, engine_alive))
+ engine_process.start()
+ engine_pid = engine_process.pid
+ assert engine_pid is not None, "Engine process failed to start."
+ logger.info("Started engine process with PID %d", engine_pid)
+
+ def _cleanup_ipc_path():
+ socket_path = ipc_path.replace("ipc://", "")
+ if os.path.exists(socket_path):
+ os.remove(socket_path)
+
+ # Ensure we clean up the local IPC socket file on exit.
+ atexit.register(_cleanup_ipc_path)
+
+ # Build RPCClient, which conforms to EngineClient Protocol.
+ engine_config = engine_args.create_engine_config()
+ build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
+ engine_pid)
+ mq_engine_client = await asyncio.get_running_loop().run_in_executor(
+ None, build_client)
+ try:
+ while True:
+ try:
+ await mq_engine_client.setup()
+ break
+ except TimeoutError:
+ if (not engine_process.is_alive()
+ or not engine_alive.value):
+ raise RuntimeError(
+ "Engine process failed to start. See stack "
+ "trace for the root cause.") from None
+
+ yield mq_engine_client # type: ignore[misc]
+ finally:
+ # Ensure rpc server process was terminated
+ engine_process.terminate()
+
+ # Close all open connections to the backend
+ mq_engine_client.close()
+
+ # Wait for engine process to join
+ engine_process.join(4)
+ if engine_process.exitcode is None:
+ # Kill if taking longer than 5 seconds to stop
+ engine_process.kill()
+
+ # Lazy import for prometheus multiprocessing.
+ # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+ # before prometheus_client is imported.
+ # See https://prometheus.github.io/client_python/multiprocess/
+ from prometheus_client import multiprocess
+ multiprocess.mark_process_dead(engine_process.pid)
+
+
+ router = APIRouter()
+
+
+ def mount_metrics(app: FastAPI):
+ # Lazy import for prometheus multiprocessing.
+ # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+ # before prometheus_client is imported.
+ # See https://prometheus.github.io/client_python/multiprocess/
+ from prometheus_client import (CollectorRegistry, make_asgi_app,
+ multiprocess)
+
+ prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
+ if prometheus_multiproc_dir_path is not None:
+ logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
+ prometheus_multiproc_dir_path)
+ registry = CollectorRegistry()
+ multiprocess.MultiProcessCollector(registry)
+
+ # Add prometheus asgi middleware to route /metrics requests
+ metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+ else:
+ # Add prometheus asgi middleware to route /metrics requests
+ metrics_route = Mount("/metrics", make_asgi_app())

+ # Workaround for 307 Redirect for /metrics
+ metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+ app.routes.append(metrics_route)

- # Add prometheus asgi middleware to route /metrics requests
- route = Mount("/metrics", make_asgi_app())
- # Workaround for 307 Redirect for /metrics
- route.path_regex = re.compile('^/metrics(?P<path>.*)$')
- app.routes.append(route)

+ def base(request: Request) -> OpenAIServing:
+ # Reuse the existing instance
+ return tokenization(request)

- @app.exception_handler(RequestValidationError)
- async def validation_exception_handler(_, exc):
- err = openai_serving_chat.create_error_response(message=str(exc))
- return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

+ def chat(request: Request) -> Optional[OpenAIServingChat]:
+ return request.app.state.openai_serving_chat

- @app.get("/health")
- async def health() -> Response:
+
+ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
+ return request.app.state.openai_serving_completion
+
+
+ def pooling(request: Request) -> Optional[OpenAIServingPooling]:
+ return request.app.state.openai_serving_pooling
+
+
+ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
+ return request.app.state.openai_serving_embedding
+
+
+ def score(request: Request) -> Optional[OpenAIServingScores]:
+ return request.app.state.openai_serving_scores
+
+
+ def tokenization(request: Request) -> OpenAIServingTokenization:
+ return request.app.state.openai_serving_tokenization
+
+
+ def engine_client(request: Request) -> EngineClient:
+ return request.app.state.engine_client
+
+
+ @router.get("/health")
+ async def health(raw_request: Request) -> Response:
  """Health check."""
- await openai_serving_chat.engine.check_health()
+ await engine_client(raw_request).check_health()
  return Response(status_code=200)


- @app.get("/v1/models")
- async def show_available_models():
- models = await openai_serving_chat.show_available_models()
+ @router.post("/tokenize")
+ @with_cancellation
+ async def tokenize(request: TokenizeRequest, raw_request: Request):
+ handler = tokenization(raw_request)
+
+ generator = await handler.create_tokenize(request, raw_request)
+ if isinstance(generator, ErrorResponse):
+ return JSONResponse(content=generator.model_dump(),
+ status_code=generator.code)
+ elif isinstance(generator, TokenizeResponse):
+ return JSONResponse(content=generator.model_dump())
+
+ assert_never(generator)
+
+
+ @router.post("/detokenize")
+ @with_cancellation
+ async def detokenize(request: DetokenizeRequest, raw_request: Request):
+ handler = tokenization(raw_request)
+
+ generator = await handler.create_detokenize(request, raw_request)
+ if isinstance(generator, ErrorResponse):
+ return JSONResponse(content=generator.model_dump(),
+ status_code=generator.code)
+ elif isinstance(generator, DetokenizeResponse):
+ return JSONResponse(content=generator.model_dump())
+
+ assert_never(generator)
+
+
+ @router.get("/v1/models")
+ async def show_available_models(raw_request: Request):
+ handler = base(raw_request)
+
+ models = await handler.show_available_models()
  return JSONResponse(content=models.model_dump())


- @app.get("/version")
+ @router.get("/version")
  async def show_version():
- ver = {"version": vllm.__version__}
+ ver = {"version": VLLM_VERSION}
  return JSONResponse(content=ver)


- @app.post("/v1/chat/completions")
+ @router.post("/v1/chat/completions")
+ @with_cancellation
  async def create_chat_completion(request: ChatCompletionRequest,
  raw_request: Request):
- generator = await openai_serving_chat.create_chat_completion(
- request, raw_request)
+ handler = chat(raw_request)
+ if handler is None:
+ return base(raw_request).create_error_response(
+ message="The model does not support Chat Completions API")
+
+ generator = await handler.create_chat_completion(request, raw_request)
+
  if isinstance(generator, ErrorResponse):
  return JSONResponse(content=generator.model_dump(),
  status_code=generator.code)
- if request.stream:
- return StreamingResponse(content=generator,
- media_type="text/event-stream")
- else:
+
+ elif isinstance(generator, ChatCompletionResponse):
  return JSONResponse(content=generator.model_dump())

+ return StreamingResponse(content=generator, media_type="text/event-stream")
+

- @app.post("/v1/completions")
+ @router.post("/v1/completions")
+ @with_cancellation
  async def create_completion(request: CompletionRequest, raw_request: Request):
- generator = await openai_serving_completion.create_completion(
- request, raw_request)
+ handler = completion(raw_request)
+ if handler is None:
+ return base(raw_request).create_error_response(
+ message="The model does not support Completions API")
+
+ generator = await handler.create_completion(request, raw_request)
  if isinstance(generator, ErrorResponse):
  return JSONResponse(content=generator.model_dump(),
  status_code=generator.code)
- if request.stream:
- return StreamingResponse(content=generator,
- media_type="text/event-stream")
+ elif isinstance(generator, CompletionResponse):
+ return JSONResponse(content=generator.model_dump())
+
+ return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+ @router.post("/v1/embeddings")
+ @with_cancellation
+ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+ handler = embedding(raw_request)
+ if handler is None:
+ fallback_handler = pooling(raw_request)
+ if fallback_handler is None:
+ return base(raw_request).create_error_response(
+ message="The model does not support Embeddings API")
+
+ logger.warning(
+ "Embeddings API will become exclusive to embedding models "
+ "in a future release. To return the hidden states directly, "
+ "use the Pooling API (`/pooling`) instead.")
+
+ res = await fallback_handler.create_pooling(request, raw_request)
+ if isinstance(res, PoolingResponse):
+ generator = EmbeddingResponse(
+ id=res.id,
+ object=res.object,
+ created=res.created,
+ model=res.model,
+ data=[
+ EmbeddingResponseData(
+ index=d.index,
+ embedding=d.data, # type: ignore
+ ) for d in res.data
+ ],
+ usage=res.usage,
+ )
+ else:
+ generator = res
  else:
+ generator = await handler.create_embedding(request, raw_request)
+
+ if isinstance(generator, ErrorResponse):
+ return JSONResponse(content=generator.model_dump(),
+ status_code=generator.code)
+ elif isinstance(generator, EmbeddingResponse):
  return JSONResponse(content=generator.model_dump())

+ assert_never(generator)

- if __name__ == "__main__":
- args = parse_args()
+
+ @router.post("/pooling")
+ @with_cancellation
+ async def create_pooling(request: PoolingRequest, raw_request: Request):
+ handler = pooling(raw_request)
+ if handler is None:
+ return base(raw_request).create_error_response(
+ message="The model does not support Pooling API")
+
+ generator = await handler.create_pooling(request, raw_request)
+ if isinstance(generator, ErrorResponse):
+ return JSONResponse(content=generator.model_dump(),
+ status_code=generator.code)
+ elif isinstance(generator, PoolingResponse):
+ return JSONResponse(content=generator.model_dump())
+
+ assert_never(generator)
+
+
+ @router.post("/score")
+ @with_cancellation
+ async def create_score(request: ScoreRequest, raw_request: Request):
+ handler = score(raw_request)
+ if handler is None:
+ return base(raw_request).create_error_response(
+ message="The model does not support Score API")
+
+ generator = await handler.create_score(request, raw_request)
+ if isinstance(generator, ErrorResponse):
+ return JSONResponse(content=generator.model_dump(),
+ status_code=generator.code)
+ elif isinstance(generator, ScoreResponse):
+ return JSONResponse(content=generator.model_dump())
+
+ assert_never(generator)
+
+
+ @router.post("/v1/score")
+ @with_cancellation
+ async def create_score_v1(request: ScoreRequest, raw_request: Request):
+ logger.warning(
+ "To indicate that Score API is not part of standard OpenAI API, we "
+ "have moved it to `/score`. Please update your client accordingly.")
+
+ return await create_score(request, raw_request)
+
+
+ if envs.VLLM_TORCH_PROFILER_DIR:
+ logger.warning(
+ "Torch Profiler is enabled in the API server. This should ONLY be "
+ "used for local development!")
+
+ @router.post("/start_profile")
+ async def start_profile(raw_request: Request):
+ logger.info("Starting profiler...")
+ await engine_client(raw_request).start_profile()
+ logger.info("Profiler started.")
+ return Response(status_code=200)
+
+ @router.post("/stop_profile")
+ async def stop_profile(raw_request: Request):
+ logger.info("Stopping profiler...")
+ await engine_client(raw_request).stop_profile()
+ logger.info("Profiler stopped.")
+ return Response(status_code=200)
+
+
+ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+ logger.warning(
+ "Lora dynamic loading & unloading is enabled in the API server. "
+ "This should ONLY be used for local development!")
+
+ @router.post("/v1/load_lora_adapter")
+ async def load_lora_adapter(request: LoadLoraAdapterRequest,
+ raw_request: Request):
+ for route in [chat, completion, embedding]:
+ handler = route(raw_request)
+ if handler is not None:
+ response = await handler.load_lora_adapter(request)
+ if isinstance(response, ErrorResponse):
+ return JSONResponse(content=response.model_dump(),
+ status_code=response.code)
+
+ return Response(status_code=200, content=response)
+
+ @router.post("/v1/unload_lora_adapter")
+ async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+ raw_request: Request):
+ for route in [chat, completion, embedding]:
+ handler = route(raw_request)
+ if handler is not None:
+ response = await handler.unload_lora_adapter(request)
+ if isinstance(response, ErrorResponse):
+ return JSONResponse(content=response.model_dump(),
+ status_code=response.code)
+
+ return Response(status_code=200, content=response)
+
+
+ def build_app(args: Namespace) -> FastAPI:
+ if args.disable_fastapi_docs:
+ app = FastAPI(openapi_url=None,
+ docs_url=None,
+ redoc_url=None,
+ lifespan=lifespan)
+ else:
+ app = FastAPI(lifespan=lifespan)
+ app.include_router(router)
+ app.root_path = args.root_path
+
+ mount_metrics(app)

  app.add_middleware(
  CORSMiddleware,
@@ -142,18 +563,43 @@ if __name__ == "__main__":
  allow_headers=args.allowed_headers,
  )

- token = os.environ.get("VLLM_API_KEY") or args.api_key
- if token:
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(_, exc):
+ err = ErrorResponse(message=str(exc),
+ type="BadRequestError",
+ code=HTTPStatus.BAD_REQUEST)
+ return JSONResponse(err.model_dump(),
+ status_code=HTTPStatus.BAD_REQUEST)
+
+ if token := envs.VLLM_API_KEY or args.api_key:
+
  @app.middleware("http")
  async def authentication(request: Request, call_next):
- root_path = "" if args.root_path is None else args.root_path
- if not request.url.path.startswith(f"{root_path}/v1"):
+ if request.method == "OPTIONS":
+ return await call_next(request)
+ url_path = request.url.path
+ if app.root_path and url_path.startswith(app.root_path):
+ url_path = url_path[len(app.root_path):]
+ if not url_path.startswith("/v1"):
  return await call_next(request)
  if request.headers.get("Authorization") != "Bearer " + token:
  return JSONResponse(content={"error": "Unauthorized"},
  status_code=401)
  return await call_next(request)

+ if args.enable_request_id_headers:
+ logger.warning(
+ "CAUTION: Enabling X-Request-Id headers in the API Server. "
+ "This can harm performance at high QPS.")
+
+ @app.middleware("http")
+ async def add_request_id(request: Request, call_next):
+ request_id = request.headers.get(
+ "X-Request-Id") or uuid.uuid4().hex
+ response = await call_next(request)
+ response.headers["X-Request-Id"] = request_id
+ return response
+
  for middleware in args.middleware:
  module_path, object_name = middleware.rsplit(".", 1)
  imported = getattr(importlib.import_module(module_path), object_name)
@@ -162,35 +608,174 @@ if __name__ == "__main__":
  elif inspect.iscoroutinefunction(imported):
  app.middleware("http")(imported)
  else:
- invalidInputError(False, (f"Invalid middleware {middleware}. "
- f"Must be a function or a class."))
+ raise ValueError(f"Invalid middleware {middleware}. "
+ f"Must be a function or a class.")
+
+ return app

- logger.info("vLLM API server version %s", vllm.__version__)
- logger.info("args: %s", args)

+ def init_app_state(
+ engine_client: EngineClient,
+ model_config: ModelConfig,
+ state: State,
+ args: Namespace,
+ ) -> None:
  if args.served_model_name is not None:
  served_model_names = args.served_model_name
  else:
  served_model_names = [args.model]
- engine_args = AsyncEngineArgs.from_cli_args(args)
- engine = IPEXLLMAsyncLLMEngine.from_engine_args(
- engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
- load_in_low_bit=args.load_in_low_bit,
+
+ if args.disable_log_requests:
+ request_logger = None
+ else:
+ request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+ base_model_paths = [
+ BaseModelPath(name=name, model_path=args.model)
+ for name in served_model_names
+ ]
+
+ state.engine_client = engine_client
+ state.log_stats = not args.disable_log_stats
+
+ resolved_chat_template = load_chat_template(args.chat_template)
+ logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+
+ state.openai_serving_chat = OpenAIServingChat(
+ engine_client,
+ model_config,
+ base_model_paths,
+ args.response_role,
+ lora_modules=args.lora_modules,
+ prompt_adapters=args.prompt_adapters,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ enable_auto_tools=args.enable_auto_tool_choice,
+ tool_parser=args.tool_call_parser,
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ ) if model_config.runner_type == "generate" else None
+ state.openai_serving_completion = OpenAIServingCompletion(
+ engine_client,
+ model_config,
+ base_model_paths,
+ lora_modules=args.lora_modules,
+ prompt_adapters=args.prompt_adapters,
+ request_logger=request_logger,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ ) if model_config.runner_type == "generate" else None
+ state.openai_serving_pooling = OpenAIServingPooling(
+ engine_client,
+ model_config,
+ base_model_paths,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ ) if model_config.runner_type == "pooling" else None
+ state.openai_serving_embedding = OpenAIServingEmbedding(
+ engine_client,
+ model_config,
+ base_model_paths,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ ) if model_config.task == "embed" else None
+ state.openai_serving_scores = OpenAIServingScores(
+ engine_client,
+ model_config,
+ base_model_paths,
+ request_logger=request_logger
+ ) if model_config.task == "score" else None
+ state.openai_serving_tokenization = OpenAIServingTokenization(
+ engine_client,
+ model_config,
+ base_model_paths,
+ lora_modules=args.lora_modules,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
  )
- openai_serving_chat = OpenAIServingChat(engine, served_model_names,
- args.response_role,
- args.lora_modules,
- args.chat_template)
- openai_serving_completion = OpenAIServingCompletion(
- engine, served_model_names, args.lora_modules)

- app.root_path = args.root_path
- uvicorn.run(app,
- host=args.host,
- port=args.port,
- log_level=args.uvicorn_log_level,
- timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
- ssl_keyfile=args.ssl_keyfile,
- ssl_certfile=args.ssl_certfile,
- ssl_ca_certs=args.ssl_ca_certs,
- ssl_cert_reqs=args.ssl_cert_reqs)
+
+ def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
+ family = socket.AF_INET
+ if is_valid_ipv6_address(addr[0]):
+ family = socket.AF_INET6
+
+ sock = socket.socket(family=family, type=socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(addr)
+
+ return sock
+
+
+ async def run_server(args, **uvicorn_kwargs) -> None:
+ logger.info("vLLM API server version %s", VLLM_VERSION)
+ logger.info("args: %s", args)
+
+ if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
+ ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+
+ valide_tool_parses = ToolParserManager.tool_parsers.keys()
+ if args.enable_auto_tool_choice \
+ and args.tool_call_parser not in valide_tool_parses:
+ raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
+ f"(chose from {{ {','.join(valide_tool_parses)} }})")
+
+ # workaround to make sure that we bind the port before the engine is set up.
+ # This avoids race conditions with ray.
+ # see https://github.com/vllm-project/vllm/issues/8204
+ sock_addr = (args.host or "", args.port)
+ sock = create_server_socket(sock_addr)
+
+ # workaround to avoid footguns where uvicorn drops requests with too
+ # many concurrent requests active
+ set_ulimit()
+
+ def signal_handler(*_) -> None:
+ # Interrupt server on sigterm while initializing
+ raise KeyboardInterrupt("terminated")
+
+ signal.signal(signal.SIGTERM, signal_handler)
+
+ async with build_async_engine_client(args) as engine_client:
+ app = build_app(args)
+
+ model_config = await engine_client.get_model_config()
+ init_app_state(engine_client, model_config, app.state, args)
+
+ shutdown_task = await serve_http(
+ app,
+ host=args.host,
+ port=args.port,
+ log_level=args.uvicorn_log_level,
+ timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+ ssl_keyfile=args.ssl_keyfile,
+ ssl_certfile=args.ssl_certfile,
+ ssl_ca_certs=args.ssl_ca_certs,
+ ssl_cert_reqs=args.ssl_cert_reqs,
+ **uvicorn_kwargs,
+ )
+
+ # NB: Await server shutdown only after the backend context is exited
+ await shutdown_task
+
+ sock.close()
+
+
+ if __name__ == "__main__":
+ # NOTE(simon):
+ # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+ parser = FlexibleArgumentParser(
+ description="vLLM OpenAI-Compatible RESTful API server.")
+ parser = make_arg_parser(parser)
+ parser.add_argument(
+ "--load-in-low-bit",
+ type=str,
+ default="sym_int4",
+ help="Low-bit quantization for IPEX-LLM models")
+ args = parser.parse_args()
+ validate_parsed_serve_args(args)
+
+ uvloop.run(run_server(args))
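
Once a server built from this wheel is running, the routes added above can be exercised with any HTTP client. The sketch below is illustrative only: the base URL, port, and model name are placeholders, and it assumes the requests library is available on the client side.

# Illustrative client sketch for two of the routes added in this diff:
# /v1/chat/completions (OpenAIServingChat) and /tokenize (OpenAIServingTokenization).
# BASE_URL and MODEL are placeholders, not values defined by the package.
import requests

BASE_URL = "http://localhost:8000"   # assumed local host/port
MODEL = "your-served-model-name"     # whatever was passed via --model/--served-model-name

chat = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": MODEL,
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 32,
    },
    timeout=120,
)
chat.raise_for_status()
print(chat.json()["choices"][0]["message"]["content"])

tokens = requests.post(
    f"{BASE_URL}/tokenize",
    json={"model": MODEL, "prompt": "Hello!"},
    timeout=30,
)
print(tokens.json())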