arbor-ai 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +12 -0
- arbor/server/api/routes/grpo.py +4 -1
- arbor/server/api/routes/inference.py +11 -16
- arbor/server/services/grpo_manager.py +179 -98
- arbor/server/services/inference/vllm_client.py +445 -0
- arbor/server/services/inference/vllm_serve.py +2335 -0
- arbor/server/services/inference_manager.py +145 -272
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +157 -53
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.1.15.dist-info}/METADATA +4 -5
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.1.15.dist-info}/RECORD +19 -13
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.1.15.dist-info}/WHEEL +1 -1
- arbor/server/services/inference/sgl_router_launch_server.py +0 -226
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.1.15.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.1.15.dist-info}/top_level.txt +0 -0
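The headline change in this release is the new in-repo vLLM serving stack (`vllm_client.py`, `vllm_serve.py`), which replaces the SGLang router launcher and exposes OpenAI-compatible `/v1/chat/completions` and `/v1/completions` routes with weight synchronization for GRPO training. As a rough illustration only (not part of the package), a client could exercise the new chat endpoint as sketched below; the host, port, and model name are placeholder assumptions, and the request/response fields follow the `OAChatCompletionRequest`/`OAChatCompletionResponse` models in the added `vllm_serve.py`.

```python
# Illustrative sketch only: query the OpenAI-compatible endpoint served by the
# new vllm_serve.py. Base URL and model name below are assumptions.
import requests

BASE_URL = "http://localhost:8000"  # assumed default host/port from ScriptArguments

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "your-model-name",  # placeholder; pass whatever model the server was launched with
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
        "max_tokens": 64,
    },
    timeout=60,
)
resp.raise_for_status()
# Response mirrors the OpenAI chat.completion shape defined by OAChatCompletionResponse.
print(resp.json()["choices"][0]["message"]["content"])
```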
arbor/server/services/inference/vllm_serve.py (new file)
@@ -0,0 +1,2335 @@
"""
OpenAI-compatible vLLM server with weight synchronization.

Usage:

```bash
uv run python vllm_server.py --model <model_name> --port <port>
```

Supports:
- /v1/chat/completions
- /v1/completions
"""

import argparse
import asyncio
import inspect
import json
import logging
import os
import threading
import time
import traceback
from collections import defaultdict
from collections.abc import Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection as MPConnection
from typing import Any
from typing import Any as AnyType
from typing import Literal, Optional
from uuid import uuid4

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ConfigDict
from trl import TrlParser
from vllm import LLM, SamplingParams
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.utils import StatelessProcessGroup
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import get_open_port

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)  # Ensure logger is defined

# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following
# error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use
# the 'spawn' start method
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# At the global level, after imports and logger setup:
pipe_lock = threading.Lock()  # Global lock for pipe operations
request_queue: Optional[asyncio.Queue] = None
batch_processor_task: Optional[asyncio.Task] = None

# Generation tracking
active_generation_count = 0
generation_count_lock = asyncio.Lock()

# Weight update throttling
MAX_CONCURRENT_WEIGHT_UPDATES = 5
weight_update_semaphore = asyncio.Semaphore(MAX_CONCURRENT_WEIGHT_UPDATES)

# Worker rotation for load balancing
worker_rotation_index = 0
worker_rotation_lock = asyncio.Lock()


async def get_next_worker_connection(connections: list[AnyType]) -> tuple[int, AnyType]:
    """Get the next worker connection using round-robin rotation."""
    global worker_rotation_index
    async with worker_rotation_lock:
        if not connections:
            raise RuntimeError("No worker connections available")
        worker_idx = worker_rotation_index % len(connections)
        worker_rotation_index += 1
        return worker_idx, connections[worker_idx]


# -------- OpenAI /v1/chat/completions Pydantic Models ---------- #
class OAChatMessage(BaseModel):
    role: str
    content: str


class OAChatResponseFormat(BaseModel):
    type: Literal["json_schema", "json_object"]
    json_schema: Optional[dict] = None

    model_config = ConfigDict(frozen=True)

    def __hash__(self):
        return hash(
            (
                self.type,
                (
                    json.dumps(self.json_schema, sort_keys=True)
                    if self.json_schema
                    else None
                ),
            )
        )


class OAChatCompletionRequest(BaseModel):
    model: str
    messages: list[OAChatMessage]
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0
    max_tokens: int | None = 1024
    n: int | None = 1
    stop: str | list[str] | None = None
    stream: bool = False  # not supported
    extra_body: dict | None = None
    response_format: dict | None = None
    # supported by vLLM:
    # guided_decoding, include_stop_str_in_output, skip_special_tokens, spaces_between_special_tokens


class OAChatChoice(BaseModel):
    index: int
    message: OAChatMessage
    finish_reason: str | None = "stop"


class OAChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[OAChatChoice]


# -------- OpenAI /v1/completions Pydantic Models ---------- #
class OACompletionRequest(BaseModel):
    model: str
    prompt: str | list[str]
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0
    max_tokens: int | None = 1024
    n: int = 1
    stop: str | list[str] | None = None
    stream: bool = False  # not supported
    extra_body: dict | None = None


class OACompletionChoice(BaseModel):
    index: int
    text: str
    logprobs: dict | None = None
    finish_reason: str | None = "length"


class OACompletionResponse(BaseModel):
    id: str
    object: str = "completion"
    created: int
    model: str
    choices: list[OACompletionChoice]


# ---------------------------------------------------------------------- #


def send_and_recv(conn: MPConnection, payload: dict):
    """Helper to send a payload and receive a response over a pipe."""
    # Use the global pipe_lock
    with pipe_lock:
        conn.send(payload)
        return conn.recv()


async def async_send_and_recv(conn: MPConnection, payload: dict, timeout: float = 30.0):
    """Async helper to send a payload and receive a response with timeout."""
    loop = asyncio.get_running_loop()

    # Send the payload in a thread to avoid blocking
    async with asyncio.timeout(timeout):
        try:
            # Use the global pipe_lock in the executor
            await loop.run_in_executor(None, lambda: pipe_lock.acquire())
            try:
                await loop.run_in_executor(None, conn.send, payload)

                # Poll for response with timeout
                start_time = asyncio.get_event_loop().time()
                while asyncio.get_event_loop().time() - start_time < timeout:
                    if await loop.run_in_executor(None, conn.poll, 0.1):
                        result = await loop.run_in_executor(None, conn.recv)
                        return result
                    await asyncio.sleep(0.05)  # Small sleep to avoid busy waiting

                raise asyncio.TimeoutError(
                    f"Worker did not respond within {timeout} seconds"
                )
            finally:
                await loop.run_in_executor(None, lambda: pipe_lock.release())
        except asyncio.TimeoutError:
            logger.error(f"Timeout waiting for worker response after {timeout}s")
            raise
        except Exception as e:
            logger.error(f"Error in async_send_and_recv: {e}", exc_info=True)
            raise


class WeightSyncWorkerExtension:
    """
    A vLLM worker extension that enables weight synchronization between a client and multiple server workers.

    This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` to handle
    efficient GPU-based communication using NCCL. The primary purpose of this class is to receive updated model weights
    from a client process and distribute them to all worker processes participating in model inference.
    """

    # The following attributes are initialized when `init_communicator` method is called.
    pynccl_comm = None  # Communicator for weight updates
    client_rank = None  # Source rank for broadcasting updated weights

    def init_communicator(self, host: str, port: int, world_size: int) -> None:
        """
        Initializes the weight update communicator using a stateless process group.

        This method creates a `StatelessProcessGroup` that allows external training processes to
        communicate with vLLM workers without interfering with the global torch distributed group.

        Args:
            host (`str`):
                Hostname or IP address of the master node.
            port (`int`):
                Port number to be used for communication.
            world_size (`int`):
                Total number of participating processes in the update group.
        """
        if self.pynccl_comm is not None:
            raise RuntimeError(
                "Weight update group already initialized. Call close_communicator first."
            )

        # Get the rank of the current worker in the global world group.
        rank = get_world_group().rank

        # Log device information for debugging
        logger.info(f"[WORKER] Initializing communicator: rank={rank}, device={self.device}, world_size={world_size}")  # type: ignore

        # Create a stateless process group to manage communication between training processes and vLLM workers.
        pg = StatelessProcessGroup.create(
            host=host, port=port, rank=rank, world_size=world_size
        )

        # Initialize the NCCL-based communicator for weight synchronization.
        self.pynccl_comm = PyNcclCommunicator(pg, device=self.device)  # type: ignore

        # The client process that sends updated weights has the highest rank (world_size - 1).
        self.client_rank = world_size - 1

    def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
        """
        Receives updated weights from the client process and updates the named parameter in the model.

        Args:
            name (`str`):
                Name of the weight tensor being updated.
            dtype (`torch.dtype`):
                Data type of the weight tensor (e.g., `torch.float32`).
            shape (`Sequence[int]`):
                Shape of the weight tensor.
        """
        import logging

        logger = logging.getLogger(__name__)
        logger.debug(
            f"[WORKER] Received weight update request for {name}, dtype={dtype}, shape={shape}"
        )

        if self.pynccl_comm is None:
            raise RuntimeError(
                "Communicator not initialized. Call `init_communicator` first."
            )

        dtype = getattr(torch, dtype.split(".")[-1])

        # Allocate memory for the incoming weight tensor on the correct device.
        weight = torch.empty(shape, dtype=dtype, device=self.device)  # type: ignore

        logger.debug(f"[WORKER] Starting NCCL broadcast for {name}")
        # Use NCCL to broadcast the updated weights from the client (src) to all workers.
        self.pynccl_comm.broadcast(weight, src=self.client_rank)  # type: ignore
        logger.debug(f"[WORKER] NCCL broadcast complete, waiting at barrier for {name}")
        self.pynccl_comm.group.barrier()
        logger.debug(f"[WORKER] Barrier passed, loading weights for {name}")

        # Load the received weights into the model.
        self.model_runner.model.load_weights(weights=[(name, weight)])  # type: ignore
        logger.debug(f"[WORKER] Weight update complete for {name}")

    def close_communicator(self) -> None:
        """
        Closes the communicator when weight synchronization is no longer needed.

        This method deletes the NCCL communicator to release associated resources.
        """

        if self.pynccl_comm is not None:
            del self.pynccl_comm
            self.pynccl_comm = None  # Ensure attribute is reset to None
            self.client_rank = None  # Ensure attribute is reset to None


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        model (`str`):
            Model name or path to load the model from.
        revision (`str` or `None`, *optional*, defaults to `None`):
            Revision to use for the model. If not specified, the default branch will be used.
        tensor_parallel_size (`int`, *optional*, defaults to `1`):
            Number of tensor parallel workers to use.
        data_parallel_size (`int`, *optional*, defaults to `1`):
            Number of data parallel workers to use.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            Host address to run the server on.
        port (`int`, *optional*, defaults to `8000`):
            Port to run the server on.
        gpu_memory_utilization (`float`, *optional*, defaults to `0.95`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
            Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
            this feature.
        enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
            Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
            model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
        kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
        log_level (`str`, *optional*, defaults to `"info"`):
            Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
            `"trace"`.
        max_batch_size (int):
            Maximum number of requests to process in one LLM call from the active pool.
        batch_request_timeout_seconds (int):
            Timeout in seconds for a single request waiting for its turn and completion.
        token_chunk_size (int):
            Number of tokens to generate per iteration per request in token-chunk dynamic batching.
    """

    model: str = field(metadata={"help": "Model name or path to load the model from."})
    revision: Optional[str] = field(
        default=None,
        metadata={
            "help": "Revision to use for the model. If not specified, the default branch will be used."
        },
    )
    tensor_parallel_size: int = field(
        default=1,
        metadata={"help": "Number of tensor parallel workers to use."},
    )
    data_parallel_size: int = field(
        default=1,
        metadata={"help": "Number of data parallel workers to use."},
    )
    host: str = field(
        default="0.0.0.0",
        metadata={"help": "Host address to run the server on."},
    )
    port: int = field(
        default=8000,
        metadata={"help": "Port to run the server on."},
    )
    gpu_memory_utilization: float = field(
        default=0.95,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    dtype: str = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    max_model_len: Optional[int] = field(
        default=8192,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
            "hardware support this feature."
        },
    )
    enforce_eager: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
            "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
            "execution in hybrid."
        },
    )
    kv_cache_dtype: str = field(
        default="auto",
        metadata={
            "help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
        },
    )
    log_level: str = field(
        default="debug",
        metadata={
            "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', "
            "'trace'."
        },
    )
    max_batch_size: int = field(
        default=32,
        metadata={
            "help": "Maximum number of requests to process in one LLM call from the active pool."
        },
    )
    batch_request_timeout_seconds: int = field(
        default=300,
        metadata={
            "help": "Timeout in seconds for a single request waiting for its turn and completion."
        },
    )
    token_chunk_size: int = field(
        default=64,
        metadata={
            "help": "Number of tokens to generate per iteration in token-chunk dynamic batching."
        },
    )


# Global/module-level variables for token-chunk dynamic batching
_SAMPLING_PARAM_NAMES: Optional[frozenset[str]] = None


@dataclass(frozen=True)
class PoolSignature:
    model_name: str
    request_type: Literal["chat", "completion"]
    # Excludes max_tokens and stream
    sampling_params_tuple: tuple[tuple[str, AnyType], ...]
    extra_body_params_tuple: tuple[tuple[str, AnyType], ...]


@dataclass
class PooledRequestState:
    original_request: AnyType  # OAChatCompletionRequest or OACompletionRequest
    completion_event: asyncio.Event
    result_container: list
    request_id: str
    request_type: Literal["chat", "completion"]
    pool_signature: PoolSignature  # Store the signature for quick checks
    effective_max_tokens: int
    accumulated_content: str = ""
    generated_token_count: int = 0
    original_chat_messages: Optional[list[OAChatMessage]] = None
    original_prompt: Optional[str | list[str]] = None
    error: Optional[Exception] = None
    finish_reason: Optional[str] = None
    completed_and_signaled: bool = False
    timed_out: bool = False

    @property
    def is_complete(self) -> bool:
        """Single source of truth for whether this request is complete and should not be processed further."""
        # Already finalized
        if self.completed_and_signaled:
            return True

        # Error state
        if self.error is not None:
            return True

        # Reached token limit
        if self.generated_token_count >= self.effective_max_tokens:
            return True

        # Not enough room for meaningful generation (less than 1 token)
        tokens_remaining = self.effective_max_tokens - self.generated_token_count
        if tokens_remaining < 1:
            return True

        # vLLM indicated completion - but ignore "length" as that's just the chunk limit
        if self.finish_reason is not None and self.finish_reason != "length":
            return True

        return False


pending_requests_by_signature: defaultdict[PoolSignature, list[PooledRequestState]] = (
    defaultdict(list)
)
active_pool_signature: Optional[PoolSignature] = None
active_pool_requests: list[PooledRequestState] = []


def _get_sampling_param_names() -> frozenset[str]:
    global _SAMPLING_PARAM_NAMES
    if _SAMPLING_PARAM_NAMES is None:
        _SAMPLING_PARAM_NAMES = frozenset(
            inspect.signature(SamplingParams).parameters.keys()
            | frozenset({"response_format"})
        )
    return _SAMPLING_PARAM_NAMES


def create_pool_signature(
    model_name: str,
    request_type: Literal["chat", "completion"],
    raw_request_params: dict[
        str, AnyType
    ],  # Contains all original request fields like temp, top_p, etc.
    extra_body: Optional[dict[str, AnyType]],
) -> PoolSignature:
    valid_sampling_keys = _get_sampling_param_names()

    sig_sampling_items = []
    key_openai_to_vllm_map = {
        "temperature": "temperature",
        "top_p": "top_p",
        "n": "n",
        "presence_penalty": "presence_penalty",
        "frequency_penalty": "frequency_penalty",
        "stop": "stop",
        "seed": "seed",
        "ignore_eos": "ignore_eos",
        "min_tokens": "min_tokens",
        "response_format": "response_format",
    }

    # Use defaults from Pydantic models if not provided in request
    param_defaults_for_sig = {
        "temperature": OAChatCompletionRequest.model_fields["temperature"].default,
        "top_p": OAChatCompletionRequest.model_fields["top_p"].default,
        "presence_penalty": OAChatCompletionRequest.model_fields[
            "presence_penalty"
        ].default,
        "frequency_penalty": OAChatCompletionRequest.model_fields[
            "frequency_penalty"
        ].default,
        "n": OAChatCompletionRequest.model_fields["n"].default,
        "stop": OAChatCompletionRequest.model_fields["stop"].default,
        "response_format": OAChatCompletionRequest.model_fields[
            "response_format"
        ].default,
        # stop: None, seed: None, ignore_eos: False, min_tokens: 0
    }

    for oa_key, vllm_key in key_openai_to_vllm_map.items():
        if vllm_key in valid_sampling_keys:
            value = raw_request_params.get(
                oa_key, param_defaults_for_sig.get(oa_key)
            )  # Use default if not in request
            # For 'stop', ensure it's a tuple if it's a list for hashability
            if vllm_key == "stop" and isinstance(value, list):
                value = tuple(value)
            if vllm_key == "response_format" and isinstance(value, dict):
                value = json.dumps(value, sort_keys=True)
            if (
                value is not None
            ):  # Only add if value is meaningfully set or defaulted for signature
                sig_sampling_items.append((vllm_key, value))

    # Sort for stable signature
    sig_sampling_items.sort(key=lambda item: item[0])

    filtered_extra_body_items = []
    if extra_body:
        for k, v in sorted(extra_body.items()):
            # Only include extra_body items that are NOT already part of vLLM SamplingParams
            # to avoid them influencing signature if they are just alternative ways to pass standard params.
            # This primarily targets things like 'guided_decoding_regex'.
            if k not in valid_sampling_keys:
                filtered_extra_body_items.append((k, v))

    return PoolSignature(
        model_name=model_name,
        request_type=request_type,
        sampling_params_tuple=tuple(sig_sampling_items),
        extra_body_params_tuple=tuple(filtered_extra_body_items),
    )


def llm_worker(
    script_args: ScriptArguments,
    data_parallel_rank: int,
    master_port: int,
    connection: MPConnection,
) -> None:
    # Set required environment variables for DP to work with vLLM
    os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
    os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)

    llm = LLM(
        model=script_args.model,
        revision=script_args.revision,
        tensor_parallel_size=script_args.tensor_parallel_size,
        gpu_memory_utilization=script_args.gpu_memory_utilization,
        enforce_eager=script_args.enforce_eager,
        dtype=script_args.dtype,
        enable_prefix_caching=script_args.enable_prefix_caching,
        max_model_len=script_args.max_model_len,
        worker_extension_cls="arbor.server.services.inference.vllm_serve.WeightSyncWorkerExtension",
    )

    # Send ready signal to parent process
    connection.send({"status": "ready"})

    while True:
        # Wait for commands from the parent process
        try:
            command = connection.recv()
        except KeyboardInterrupt:
            llm.collective_rpc(method="close_communicator")
            break

        # Handle commands
        if command["type"] in ["call", "fire_and_forget"]:
            method_name = command["method"]
            args, kwargs = command.get("args", ()), command.get("kwargs", {})

            # Add debugging
            logger.debug(
                f"[WORKER {data_parallel_rank}] Received command: {method_name}"
            )

            try:
                method = getattr(llm, method_name)
                logger.debug(
                    f"[WORKER {data_parallel_rank}] Calling {method_name} with kwargs keys: {list(kwargs.keys()) if kwargs else 'none'}"
                )

                # Call the method
                result = method(*args, **kwargs)

                logger.debug(
                    f"[WORKER {data_parallel_rank}] {method_name} completed, result type: {type(result)}"
                )

                if command["type"] == "call":
                    # Send result back
                    logger.debug(f"[WORKER {data_parallel_rank}] Sending result back")
                    connection.send(result)
                    logger.debug(f"[WORKER {data_parallel_rank}] Result sent")
            except Exception as e:
                logger.error(
                    f"[WORKER {data_parallel_rank}] Error in {method_name}: {e}",
                    exc_info=True,
                )
                if command["type"] == "call":
                    # Send error back as a special result
                    connection.send(
                        {"error": str(e), "traceback": traceback.format_exc()}
                    )
        elif command["type"] == "shutdown":
            logger.info(f"[WORKER {data_parallel_rank}] Received shutdown command")
            break


def chunk_list(lst: list, n: int) -> list[list]:
    """
    Split list `lst` into `n` evenly distributed sublists.

    Example:
        >>> chunk_list([1, 2, 3, 4, 5, 6], 2)
        [[1, 2, 3], [4, 5, 6]]
        >>> chunk_list([1, 2, 3, 4, 5, 6], 4)
        [[1, 2], [3, 4], [5], [6]]
        >>> chunk_list([1, 2, 3, 4, 5, 6], 8)
        [[1], [2], [3], [4], [5], [6], [], []]
    """
    k, r = divmod(len(lst), n)
    return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]


async def batch_processing_loop(
    script_args: ScriptArguments,
    connections: list[AnyType],
    queue: asyncio.Queue,  # This queue now receives PooledRequestState
    logger_instance: logging.Logger,
):
    global pending_requests_by_signature, active_pool_signature, active_pool_requests
    global active_generation_count, generation_count_lock

    if not connections:
        logger_instance.error(
            "Batch Processor: No LLM workers available. Shutting down loop."
        )
        # We cannot process anything if connections are not there from the start.
        # New requests added to queue will eventually timeout or error when picked by lifespan shutdown.
        return

    while True:
        try:
            # 1. Ingest new requests (non-blocking or short timeout)
            # This part tries to fill pending_requests_by_signature from the main request_queue
            try:
                while True:  # Loop to drain current items in asyncio.Queue
                    pooled_req_state: PooledRequestState = queue.get_nowait()
                    pending_requests_by_signature[
                        pooled_req_state.pool_signature
                    ].append(pooled_req_state)
                    logger_instance.debug(
                        f"[LIFECYCLE] Request {pooled_req_state.request_id} enqueued to pending pool. "
                        f"max_tokens={pooled_req_state.effective_max_tokens}, "
                        f"signature={pooled_req_state.pool_signature}"
                    )
                    queue.task_done()  # Signal that this item from main queue is taken
            except asyncio.QueueEmpty:
                pass  # No new requests in the main queue right now

            # 2. Activate a new pool if current is empty and pending requests exist
            if not active_pool_requests and pending_requests_by_signature:
                # Simple strategy: pick the first signature that has pending requests
                # More sophisticated strategies (e.g. largest batch) could be implemented here
                # items() gives a view, convert to list to pop
                available_signatures = list(pending_requests_by_signature.keys())
                if available_signatures:
                    active_pool_signature = available_signatures[0]  # Pick one
                    active_pool_requests = pending_requests_by_signature.pop(
                        active_pool_signature
                    )
                    logger_instance.info(
                        f"[LIFECYCLE] Activated new pool with signature: {active_pool_signature}. Size: {len(active_pool_requests)}"
                    )
                    for req in active_pool_requests:
                        logger_instance.debug(
                            f"[LIFECYCLE] Request {req.request_id} moved to active pool"
                        )

            # 3. Merge new matching requests into the active pool
            if (
                active_pool_signature
                and active_pool_signature in pending_requests_by_signature
            ):
                newly_matching_requests = pending_requests_by_signature.pop(
                    active_pool_signature
                )
                active_pool_requests.extend(newly_matching_requests)
                logger_instance.info(
                    f"Merged {len(newly_matching_requests)} requests into active pool. New size: {len(active_pool_requests)}"
                )

            # 4. Process active pool chunk
            if active_pool_requests:
                # Take a sub-batch from the active pool
                # If active_pool_requests is not empty, active_pool_signature must be set.
                assert (
                    active_pool_signature is not None
                ), "active_pool_signature cannot be None if active_pool_requests is populated"

                # Filter out already-completed requests before selecting sub-batch
                available_requests = [
                    req for req in active_pool_requests if not req.is_complete
                ]

                # Log the state of requests in the pool
                logger_instance.debug(
                    f"Active pool has {len(active_pool_requests)} total requests, {len(available_requests)} available for processing"
                )
                for req in active_pool_requests:
                    logger_instance.debug(
                        f"  Request {req.request_id}: tokens={req.generated_token_count}/{req.effective_max_tokens}, "
                        f"is_complete={req.is_complete}, finish_reason={req.finish_reason}"
                    )

                if not available_requests:
                    # All requests in active pool are already completed, clear the pool
                    logger_instance.info(
                        f"All requests in active pool {active_pool_signature} are already completed. Clearing pool."
                    )
                    active_pool_requests.clear()
                    active_pool_signature = None
                    continue

                sub_batch_to_process: list[PooledRequestState] = []
                sub_batch_size = min(
                    len(available_requests), script_args.max_batch_size
                )
                sub_batch_to_process = available_requests[:sub_batch_size]

                logger_instance.debug(
                    f"[BATCH_PROCESSOR] Processing sub-batch of {len(sub_batch_to_process)} for sig: {active_pool_signature}"
                )

                # Track active generation
                async with generation_count_lock:
                    active_generation_count += 1
                    logger_instance.debug(
                        f"[BATCH_PROCESSOR] Active generation count increased to {active_generation_count}"
                    )

                try:
                    # Prepare inputs for LLM
                    # All requests in sub_batch_to_process share active_pool_signature
                    # So, sampling params (except max_tokens) are the same.

                    # Construct SamplingParams from the active_pool_signature
                    # The signature stores param tuples. Convert back to dict for SamplingParams.
                    sig_sampling_dict = dict(
                        active_pool_signature.sampling_params_tuple
                    )
                    sig_extra_body_dict = dict(
                        active_pool_signature.extra_body_params_tuple
                    )

                    # Override 'n' to 1 for chunked generation as per design.
                    # Log if original 'n' was different.
                    original_n = sig_sampling_dict.get("n", 1)
                    if original_n != 1:
                        logger_instance.warning(
                            f"Pool {active_pool_signature}: Original 'n={original_n}' overridden to n=1 for chunked generation."
                        )

                    # Calculate the minimum tokens remaining across all requests in the batch
                    min_tokens_remaining = min(
                        req.effective_max_tokens - req.generated_token_count
                        for req in sub_batch_to_process
                    )

                    # Log token calculations
                    logger_instance.debug(
                        f"[CHUNK_CALC] Calculating chunk size for {len(sub_batch_to_process)} requests:"
                    )
                    for req in sub_batch_to_process:
                        tokens_left = (
                            req.effective_max_tokens - req.generated_token_count
                        )
                        logger_instance.debug(
                            f"[CHUNK_CALC] Request {req.request_id}: {req.generated_token_count}/{req.effective_max_tokens} tokens, {tokens_left} remaining"
                        )

                    # Limit chunk size to available room
                    chunk_size = min(script_args.token_chunk_size, min_tokens_remaining)
                    logger_instance.debug(
                        f"[CHUNK_CALC] Final chunk size: {chunk_size} (configured: {script_args.token_chunk_size}, min_remaining: {min_tokens_remaining})"
                    )

                    # CRITICAL: Ensure chunk_size is at least 1 to avoid vLLM issues
                    if chunk_size <= 0:
                        logger_instance.error(
                            f"Invalid chunk size {chunk_size} calculated. Min remaining: {min_tokens_remaining}"
                        )
                        # Mark all requests as complete if we can't generate any more tokens
                        for req_state in sub_batch_to_process:
                            if not req_state.is_complete:
                                req_state.finish_reason = "length"
                                logger_instance.info(
                                    f"Request {req_state.request_id} marked complete due to no room for generation"
                                )
                        continue  # Skip this iteration

                    logger_instance.debug(
                        f"Chunk size for batch: {chunk_size} (min remaining: {min_tokens_remaining}, configured: {script_args.token_chunk_size})"
                    )

                    # Create a new dict for **kwargs, excluding 'n' as it's set explicitly
                    kwargs_for_sampling_params = {
                        k: v for k, v in sig_sampling_dict.items() if k != "n"
                    }

                    guided_decoding = None
                    if "guided_decoding_regex" in sig_extra_body_dict:
                        guided_decoding = GuidedDecodingParams(
                            backend="outlines",
                            regex=sig_extra_body_dict["guided_decoding_regex"],
                        )
                    elif "response_format" in kwargs_for_sampling_params:
                        response_format_json = json.loads(
                            kwargs_for_sampling_params["response_format"]
                        )
                        response_format_type = response_format_json["type"]
                        logger_instance.info(
                            f"Response format param: {response_format_type}"
                        )
                        if response_format_type == "json_schema":
                            json_schema = response_format_json["json_schema"]
                            guided_decoding = GuidedDecodingParams(json=json_schema)
                            logger_instance.info(
                                f"Going with json_schema {json_schema}"
                            )
                        elif response_format_type == "json_object":
                            logger_instance.info(f"Going with json_object")
                            guided_decoding = GuidedDecodingParams(json_object=True)
                        else:
                            logger_instance.info(
                                f"Response format param provided that isn't supported: {response_format_type}"
                            )
                        # remove the response_format key because we can't pass it to sampling params
                        del kwargs_for_sampling_params["response_format"]
                    else:
                        # guided_decoding should be none if we don't want to use any guided decoding
                        pass

                    vllm_sampling_params = SamplingParams(
                        **kwargs_for_sampling_params,
                        n=1,  # Generate one sequence continuation per request in the chunk
                        max_tokens=chunk_size,  # Use calculated chunk size
                        # Ensure guided_decoding is correctly set up if present in extra_body
                        guided_decoding=guided_decoding,
                        # Remove any params from extra_body that might also be in SamplingParams if they were not filtered by create_pool_signature
                        **{
                            k: v
                            for k, v in sig_extra_body_dict.items()
                            if k in _get_sampling_param_names()
                            and k != "guided_decoding_regex"
                        },
                    )

                    # --- Bucket chat requests by first chunk vs continuing ---
                    first_chunk_inputs = []
                    first_chunk_states = []
                    continue_chunk_states = []
                    prompts_for_vllm = []
                    is_chat_pool = active_pool_signature.request_type == "chat"

                    for req_state in sub_batch_to_process:
                        if is_chat_pool:
                            current_messages = []
                            if req_state.original_chat_messages:
                                current_messages.extend(
                                    [
                                        m.model_dump()
                                        for m in req_state.original_chat_messages
                                    ]
                                )

                            # For continuing generation, we need to ensure there's an assistant message to continue
                            if req_state.generated_token_count == 0:
                                # First chunk - ensure we have a valid message sequence ending with user
                                if not current_messages:
                                    logger_instance.error(
                                        f"Request {req_state.request_id} has no messages"
                                    )
                                    req_state.error = ValueError("No messages provided")
                                    continue
                                if current_messages[-1]["role"] != "user":
                                    logger_instance.error(
                                        f"Request {req_state.request_id} last message is not from user for first chunk"
                                    )
                                    req_state.error = ValueError(
                                        "Last message must be from user for first chunk"
                                    )
                                    continue
                                first_chunk_inputs.append(current_messages)
                                first_chunk_states.append(req_state)
                            else:
                                # Continuing chunk - add accumulated content as assistant message
                                if req_state.accumulated_content:
                                    current_messages.append(
                                        {
                                            "role": "assistant",
                                            "content": req_state.accumulated_content,
                                        }
                                    )
                                else:
                                    # This should not happen - we should have content if we're continuing
                                    logger_instance.error(
                                        f"Request {req_state.request_id} has no accumulated content for continuation"
                                    )
                                    req_state.error = ValueError(
                                        "No content to continue"
                                    )
                                    continue
                                continue_chunk_states.append(req_state)
                        else:
                            if isinstance(req_state.original_prompt, str):
                                current_prompt = req_state.original_prompt
                            elif isinstance(req_state.original_prompt, list):
                                current_prompt = (
                                    req_state.original_prompt[0]
                                    if req_state.original_prompt
                                    else ""
                                )
                            else:
                                current_prompt = str(req_state.original_prompt or "")
                            prompts_for_vllm.append(
                                current_prompt + req_state.accumulated_content
                            )

                    # Only process first-chunk chat requests in this tick, then continuing if no first-chunk left
                    llm_results = []
                    processed_states = []
                    if is_chat_pool:
                        loop = asyncio.get_running_loop()
                        if first_chunk_inputs:
                            # Filter out any already-completed requests from first_chunk_states
                            active_first_states = []
                            active_first_inputs = []
                            for i, req_state in enumerate(first_chunk_states):
                                if req_state.is_complete:
                                    logger_instance.debug(
                                        f"Skipping already-completed request {req_state.request_id} in first chunk processing"
                                    )
                                    continue
                                active_first_states.append(req_state)
                                active_first_inputs.append(first_chunk_inputs[i])

                            if not active_first_states:
                                logger_instance.debug(
                                    "All first chunk requests are already completed, skipping LLM call"
                                )
                                processed_states = []
                                llm_results = []
                            else:
                                flags = dict(
                                    add_generation_prompt=True,
                                    continue_final_message=False,
                                )
                                payload = {
                                    "type": "call",
                                    "method": "chat",
                                    "kwargs": {
                                        "messages": active_first_inputs,
                                        "sampling_params": vllm_sampling_params,
                                        **flags,
                                    },
                                }
                                logger_instance.debug(
                                    f"Sending first-chunk chat request to LLM with {len(active_first_inputs)} messages"
                                )

                                worker_idx = -1  # Initialize to avoid unbound variable
                                try:
                                    worker_idx, worker_conn = (
                                        await get_next_worker_connection(connections)
                                    )
                                    logger_instance.debug(
                                        f"Using worker {worker_idx} for first-chunk chat request"
                                    )
                                    llm_results = await async_send_and_recv(
                                        worker_conn, payload, timeout=60.0
                                    )
                                    logger_instance.debug(
                                        f"Received {len(llm_results)} results from LLM for first-chunk chat"
                                    )
                                except asyncio.TimeoutError:
                                    logger_instance.error(
                                        f"Worker {worker_idx} timeout for first-chunk chat after 60s"
                                    )
                                    for req_state in active_first_states:
                                        req_state.error = TimeoutError(
                                            "Worker timeout during generation"
                                        )
                                    llm_results = []
                                except Exception as e:
                                    logger_instance.error(
                                        f"Error calling LLM for first-chunk chat: {e}",
                                        exc_info=True,
                                    )
                                    for req_state in active_first_states:
                                        req_state.error = e
                                    llm_results = []
                                processed_states = active_first_states
                        elif continue_chunk_states:
                            # No first-chunk requests, process continuing requests
                            continue_chunk_inputs = []
                            # Filter out any already-completed requests
                            active_continue_states = []
                            for req_state in continue_chunk_states:
                                if req_state.is_complete:
                                    logger_instance.debug(
                                        f"Skipping already-completed request {req_state.request_id} in continue chunk processing"
                                    )
                                    continue
                                active_continue_states.append(req_state)
                                current_messages = []
                                if req_state.original_chat_messages:
                                    current_messages.extend(
                                        [
                                            m.model_dump()
                                            for m in req_state.original_chat_messages
                                        ]
                                    )

                                # Must have accumulated content to continue
                                if not req_state.accumulated_content:
                                    logger_instance.error(
                                        f"Request {req_state.request_id} has no accumulated content for continuation"
                                    )
                                    req_state.error = ValueError(
                                        "No content to continue generation"
                                    )
                                    active_continue_states.remove(
                                        req_state
                                    )  # Remove from active list
                                    continue

                                # Add the accumulated content as the assistant message to continue
                                current_messages.append(
                                    {
                                        "role": "assistant",
                                        "content": req_state.accumulated_content,
                                    }
                                )
                                continue_chunk_inputs.append(current_messages)

                            if not active_continue_states:
                                logger_instance.debug(
                                    "All continue chunk requests are already completed, skipping LLM call"
                                )
                                processed_states = []
                                llm_results = []
                            else:
                                flags = dict(
                                    add_generation_prompt=False,
                                    continue_final_message=True,
                                )
                                payload = {
                                    "type": "call",
                                    "method": "chat",
                                    "kwargs": {
                                        "messages": continue_chunk_inputs,
                                        "sampling_params": vllm_sampling_params,
                                        **flags,
                                    },
                                }
                                logger_instance.debug(
                                    f"Sending continue-chunk chat request to LLM with {len(continue_chunk_inputs)} messages"
                                )

                                worker_idx = -1  # Initialize to avoid unbound variable
                                try:
                                    worker_idx, worker_conn = (
                                        await get_next_worker_connection(connections)
                                    )
                                    logger_instance.debug(
                                        f"Using worker {worker_idx} for continue-chunk chat request"
                                    )
                                    llm_results = await async_send_and_recv(
                                        worker_conn, payload, timeout=60.0
                                    )
                                    logger_instance.debug(
                                        f"Received {len(llm_results)} results from LLM for continue-chunk chat"
                                    )
                                except asyncio.TimeoutError:
                                    logger_instance.error(
                                        f"Worker {worker_idx} timeout for continue-chunk chat after 60s"
                                    )
                                    for req_state in active_continue_states:
                                        req_state.error = TimeoutError(
                                            "Worker timeout during generation"
                                        )
                                    llm_results = []
                                except Exception as e:
                                    logger_instance.error(
                                        f"Error calling LLM for continue-chunk chat: {e}",
                                        exc_info=True,
                                    )
                                    for req_state in active_continue_states:
                                        req_state.error = e
                                    llm_results = []
                                processed_states = active_continue_states
                        else:
                            # No requests to process in this iteration
                            logger_instance.debug(
                                "No chat requests to process in this iteration"
                            )
                            processed_states = []
                            llm_results = []
                    else:
                        # completion – unchanged
                        loop = asyncio.get_running_loop()
                        payload = {
                            "type": "call",
                            "method": "generate",
                            "kwargs": {
                                "prompts": prompts_for_vllm,
                                "sampling_params": vllm_sampling_params,
                            },
                        }
                        logger_instance.debug(
                            f"Sending completion request to LLM with {len(prompts_for_vllm)} prompts"
                        )
                        worker_idx = -1  # Initialize to avoid unbound variable
                        try:
                            worker_idx, worker_conn = await get_next_worker_connection(
                                connections
                            )
                            logger_instance.debug(
                                f"Using worker {worker_idx} for completion request"
                            )
                            llm_results = await async_send_and_recv(
                                worker_conn, payload, timeout=60.0
                            )
                            logger_instance.debug(
                                f"Received {len(llm_results)} results from LLM for completion"
                            )
                        except asyncio.TimeoutError:
                            logger_instance.error(
                                f"Worker {worker_idx} timeout for completion after 60s"
                            )
                            for req_state in sub_batch_to_process:
                                req_state.error = TimeoutError(
                                    "Worker timeout during generation"
                                )
                            llm_results = []
                        except Exception as e:
                            logger_instance.error(
                                f"Error calling LLM for completion: {e}", exc_info=True
                            )
                            for req_state in sub_batch_to_process:
                                req_state.error = e
                            llm_results = []
                        processed_states = sub_batch_to_process

                    # Now, update state for each request in the processed_states
                    temp_failed_requests_in_sub_batch: list[PooledRequestState] = []

                    if is_chat_pool:
                        if processed_states and (
                            len(llm_results) != len(processed_states)
                        ):
                            logger_instance.error(
                                f"LLM result count mismatch. Expected {len(processed_states)}, got {len(llm_results)} for sig {active_pool_signature}. Marking affected requests in sub-batch as error."
                            )
                            for req_state in processed_states:
                                if not req_state.completed_and_signaled:
                                    req_state.error = RuntimeError(
                                        "LLM result mismatch in batch processing."
                                    )
                                    req_state.finish_reason = "error"
                                    temp_failed_requests_in_sub_batch.append(req_state)
                        else:
                            real_idx = 0
                            for req_state in processed_states:
                                if req_state.completed_and_signaled:
                                    continue
                                request_output = llm_results[real_idx]
                                real_idx += 1
                                if (
                                    not request_output.outputs
                                    or len(request_output.outputs) == 0
                                ):
                                    logger_instance.warning(
                                        f"Request {req_state.request_id} (idx {real_idx-1}) received no output from vLLM in chunk."
                                    )
                                    # This might happen if vLLM can't generate any tokens (e.g., due to constraints)
                                    # Mark as complete rather than error
                                    req_state.finish_reason = (
                                        "stop"  # vLLM couldn't generate
                                    )
                                    logger_instance.info(
                                        f"Request {req_state.request_id} marked complete due to empty vLLM output"
                                    )
                                    continue
                                completion_output = request_output.outputs[0]
                                new_text_chunk = completion_output.text
                                req_state.accumulated_content += new_text_chunk
                                new_token_count = len(completion_output.token_ids)
                                req_state.generated_token_count += new_token_count

                                # Store vLLM's finish reason but we'll interpret it carefully
                                vllm_finish_reason = completion_output.finish_reason
                                logger_instance.debug(
                                    f"[VLLM_RESPONSE] Request {req_state.request_id}: vLLM returned {new_token_count} tokens, finish_reason={vllm_finish_reason}"
                                )

                                # Only update our finish_reason if it's meaningful
                                if vllm_finish_reason == "length":
                                    # vLLM hit the chunk limit - only set our finish_reason if we're at our actual limit
                                    if (
                                        req_state.generated_token_count
                                        >= req_state.effective_max_tokens
                                    ):
                                        req_state.finish_reason = "length"
                                        logger_instance.debug(
                                            f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='length' - hit actual limit"
                                        )
                                    else:
                                        # Don't set finish_reason - we can continue generating
                                        logger_instance.debug(
                                            f"[FINISH_REASON] Request {req_state.request_id}: Ignoring vLLM's finish_reason='length' - only at chunk limit"
                                        )
                                elif vllm_finish_reason is not None:
                                    # Other finish reasons (stop, eos_token, etc.) are real completions
                                    req_state.finish_reason = vllm_finish_reason
                                    logger_instance.debug(
                                        f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='{vllm_finish_reason}' from vLLM"
                                    )

                                # Log detailed state for debugging
                                logger_instance.debug(
                                    f"Request {req_state.request_id} chunk result: "
                                    f"new_tokens={new_token_count}, total_tokens={req_state.generated_token_count}, "
                                    f"finish_reason={req_state.finish_reason}, chunk_text_len={len(new_text_chunk)}"
                                )

                                # Check if generation has stopped
                                if new_token_count < chunk_size:
                                    # Incomplete chunk indicates generation should stop
                                    logger_instance.info(
                                        f"Request {req_state.request_id} generated incomplete chunk. "
                                        f"Generated {new_token_count}/{chunk_size} tokens in chunk. "
                                        f"Marking as complete to prevent doom loop."
                                    )
                                    # Set finish reason if not already set by vLLM
                                    if req_state.finish_reason is None:
                                        req_state.finish_reason = (
                                            "stop"  # Generation naturally stopped
                                        )
                                        logger_instance.debug(
                                            f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='stop' due to incomplete chunk"
                                        )

                                # Log current state
                                logger_instance.debug(
                                    f"Request {req_state.request_id} chunk processed. Tokens in chunk: {new_token_count}, total: {req_state.generated_token_count}, is_complete: {req_state.is_complete}"
                                )
                    else:
                        # completion – unchanged
                        if len(llm_results) != len(processed_states):
                            logger_instance.error(
                                f"LLM result count mismatch. Expected {len(processed_states)}, got {len(llm_results)} for sig {active_pool_signature}. Marking affected requests in sub-batch as error."
                            )
                            for req_state in processed_states:
                                if not req_state.completed_and_signaled:
                                    req_state.error = RuntimeError(
                                        "LLM result mismatch in batch processing."
                                    )
                                    req_state.finish_reason = "error"
                                    temp_failed_requests_in_sub_batch.append(req_state)
                        else:
                            real_idx = 0
                            for req_state in processed_states:
                                if req_state.completed_and_signaled:
                                    continue
                                request_output = llm_results[real_idx]
                                real_idx += 1
                                if (
                                    not request_output.outputs
                                    or len(request_output.outputs) == 0
                                ):
                                    logger_instance.warning(
                                        f"Request {req_state.request_id} (idx {real_idx-1}) received no output from vLLM in chunk."
                                    )
                                    # This might happen if vLLM can't generate any tokens (e.g., due to constraints)
                                    # Mark as complete rather than error
                                    req_state.finish_reason = (
                                        "stop"  # vLLM couldn't generate
                                    )
                                    logger_instance.info(
                                        f"Request {req_state.request_id} marked complete due to empty vLLM output"
                                    )
                                    continue
                                completion_output = request_output.outputs[0]
                                new_text_chunk = completion_output.text
                                req_state.accumulated_content += new_text_chunk
                                new_token_count = len(completion_output.token_ids)
                                req_state.generated_token_count += new_token_count

                                # Store vLLM's finish reason but we'll interpret it carefully
                                vllm_finish_reason = completion_output.finish_reason
                                logger_instance.debug(
                                    f"[VLLM_RESPONSE] Request {req_state.request_id}: vLLM returned {new_token_count} tokens, finish_reason={vllm_finish_reason}"
                                )

                                # Only update our finish_reason if it's meaningful
                                if vllm_finish_reason == "length":
                                    # vLLM hit the chunk limit - only set our finish_reason if we're at our actual limit
                                    if (
                                        req_state.generated_token_count
                                        >= req_state.effective_max_tokens
                                    ):
                                        req_state.finish_reason = "length"
                                        logger_instance.debug(
                                            f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='length' - hit actual limit"
|
1403
|
+
)
|
1404
|
+
else:
|
1405
|
+
# Don't set finish_reason - we can continue generating
|
1406
|
+
logger_instance.debug(
|
1407
|
+
f"[FINISH_REASON] Request {req_state.request_id}: Ignoring vLLM's finish_reason='length' - only at chunk limit"
|
1408
|
+
)
|
1409
|
+
elif vllm_finish_reason is not None:
|
1410
|
+
# Other finish reasons (stop, eos_token, etc.) are real completions
|
1411
|
+
req_state.finish_reason = vllm_finish_reason
|
1412
|
+
logger_instance.debug(
|
1413
|
+
f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='{vllm_finish_reason}' from vLLM"
|
1414
|
+
)
|
1415
|
+
|
1416
|
+
# Log detailed state for debugging
|
1417
|
+
logger_instance.debug(
|
1418
|
+
f"Request {req_state.request_id} chunk result: "
|
1419
|
+
f"new_tokens={new_token_count}, total_tokens={req_state.generated_token_count}, "
|
1420
|
+
f"finish_reason={req_state.finish_reason}, chunk_text_len={len(new_text_chunk)}"
|
1421
|
+
)
|
1422
|
+
|
1423
|
+
# Check if generation has stopped
|
1424
|
+
if new_token_count < chunk_size:
|
1425
|
+
# Incomplete chunk indicates generation should stop
|
1426
|
+
logger_instance.info(
|
1427
|
+
f"Request {req_state.request_id} generated incomplete chunk. "
|
1428
|
+
f"Generated {new_token_count}/{chunk_size} tokens in chunk. "
|
1429
|
+
f"Marking as complete to prevent doom loop."
|
1430
|
+
)
|
1431
|
+
# Set finish reason if not already set by vLLM
|
1432
|
+
if req_state.finish_reason is None:
|
1433
|
+
req_state.finish_reason = (
|
1434
|
+
"stop" # Generation naturally stopped
|
1435
|
+
)
|
1436
|
+
logger_instance.debug(
|
1437
|
+
f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='stop' due to incomplete chunk"
|
1438
|
+
)
|
1439
|
+
|
1440
|
+
# Log current state
|
1441
|
+
logger_instance.debug(
|
1442
|
+
f"Request {req_state.request_id} chunk processed. Tokens in chunk: {new_token_count}, total: {req_state.generated_token_count}, is_complete: {req_state.is_complete}"
|
1443
|
+
)
|
1444
|
+
|
1445
|
+
# Now, handle all finished or errored requests from this sub_batch
|
1446
|
+
# These need to be removed from active_pool_requests and have their events set.
|
1447
|
+
completed_in_sub_batch: list[PooledRequestState] = []
|
1448
|
+
remaining_in_sub_batch: list[PooledRequestState] = []
|
1449
|
+
|
1450
|
+
# Iterate over sub_batch_to_process to decide their fate
|
1451
|
+
updated_active_pool = []
|
1452
|
+
|
1453
|
+
# Create a set of request_ids from sub_batch_to_process for quick lookup
|
1454
|
+
sub_batch_ids = {r.request_id for r in sub_batch_to_process}
|
1455
|
+
|
1456
|
+
for req_state in active_pool_requests: # Iterate main active pool
|
1457
|
+
if req_state.request_id not in sub_batch_ids:
|
1458
|
+
updated_active_pool.append(
|
1459
|
+
req_state
|
1460
|
+
) # Keep if not in current sub-batch
|
1461
|
+
continue
|
1462
|
+
|
1463
|
+
# req_state is from the current sub_batch. Check its status.
|
1464
|
+
logger_instance.debug(
|
1465
|
+
f"[COMPLETION_CHECK] Checking request {req_state.request_id}: "
|
1466
|
+
f"generated={req_state.generated_token_count}/{req_state.effective_max_tokens}, "
|
1467
|
+
f"finish_reason={req_state.finish_reason}, is_complete={req_state.is_complete}"
|
1468
|
+
)
|
1469
|
+
|
1470
|
+
if (
|
1471
|
+
req_state.is_complete
|
1472
|
+
and not req_state.completed_and_signaled
|
1473
|
+
):
|
1474
|
+
# Request is complete but not yet signaled
|
1475
|
+
completed_in_sub_batch.append(req_state)
|
1476
|
+
logger_instance.info(
|
1477
|
+
f"[LIFECYCLE] Request {req_state.request_id} is complete and will be finalized. "
|
1478
|
+
f"Generated {req_state.generated_token_count} tokens, finish_reason={req_state.finish_reason}"
|
1479
|
+
)
|
1480
|
+
elif not req_state.is_complete:
|
1481
|
+
# Request is not complete, keep for next chunk
|
1482
|
+
updated_active_pool.append(req_state)
|
1483
|
+
logger_instance.debug(
|
1484
|
+
f"[LIFECYCLE] Request {req_state.request_id} is not complete, keeping for next chunk"
|
1485
|
+
)
|
1486
|
+
# If already signaled (completed_and_signaled is True), don't re-add or re-signal
|
1487
|
+
|
1488
|
+
active_pool_requests = updated_active_pool
|
1489
|
+
if not active_pool_requests:
|
1490
|
+
# Store signature before setting to None for logging
|
1491
|
+
deactivated_signature = active_pool_signature
|
1492
|
+
active_pool_signature = None # Deactivate pool if empty
|
1493
|
+
logger_instance.info(
|
1494
|
+
f"Deactivated pool {deactivated_signature} as it is now empty."
|
1495
|
+
)
|
1496
|
+
|
1497
|
+
for req_state in completed_in_sub_batch:
|
1498
|
+
if req_state.completed_and_signaled:
|
1499
|
+
continue # Already handled
|
1500
|
+
|
1501
|
+
response_content: (
|
1502
|
+
OAChatCompletionResponse
|
1503
|
+
| OACompletionResponse
|
1504
|
+
| JSONResponse
|
1505
|
+
)
|
1506
|
+
if req_state.error:
|
1507
|
+
logger_instance.error(
|
1508
|
+
f"Request {req_state.request_id} failed with error: {req_state.error}"
|
1509
|
+
)
|
1510
|
+
response_content = JSONResponse(
|
1511
|
+
status_code=500,
|
1512
|
+
content={
|
1513
|
+
"error": f"Processing error: {str(req_state.error)}",
|
1514
|
+
"request_id": req_state.request_id,
|
1515
|
+
},
|
1516
|
+
)
|
1517
|
+
elif req_state.request_type == "chat":
|
1518
|
+
final_choices = [
|
1519
|
+
OAChatChoice(
|
1520
|
+
index=0,
|
1521
|
+
message=OAChatMessage(
|
1522
|
+
role="assistant",
|
1523
|
+
content=req_state.accumulated_content,
|
1524
|
+
),
|
1525
|
+
finish_reason=req_state.finish_reason,
|
1526
|
+
)
|
1527
|
+
]
|
1528
|
+
response_content = OAChatCompletionResponse(
|
1529
|
+
id=f"chatcmpl-{uuid4().hex}", # Use original request_id if available? For now, new UUID.
|
1530
|
+
created=int(datetime.now(tz=timezone.utc).timestamp()),
|
1531
|
+
model=req_state.original_request.model,
|
1532
|
+
choices=final_choices,
|
1533
|
+
)
|
1534
|
+
else: # Completion
|
1535
|
+
final_choices = [
|
1536
|
+
OACompletionChoice(
|
1537
|
+
index=0,
|
1538
|
+
text=req_state.accumulated_content,
|
1539
|
+
finish_reason=req_state.finish_reason,
|
1540
|
+
)
|
1541
|
+
]
|
1542
|
+
response_content = OACompletionResponse(
|
1543
|
+
id=f"cmpl-{uuid4().hex}",
|
1544
|
+
created=int(datetime.now(tz=timezone.utc).timestamp()),
|
1545
|
+
model=req_state.original_request.model,
|
1546
|
+
choices=final_choices,
|
1547
|
+
)
|
1548
|
+
|
1549
|
+
req_state.result_container[0] = response_content
|
1550
|
+
req_state.completion_event.set()
|
1551
|
+
req_state.completed_and_signaled = True
|
1552
|
+
|
1553
|
+
finally:
|
1554
|
+
# Always decrement active generation count
|
1555
|
+
async with generation_count_lock:
|
1556
|
+
active_generation_count -= 1
|
1557
|
+
logger_instance.debug(
|
1558
|
+
f"[BATCH_PROCESSOR] Active generation count decreased to {active_generation_count}"
|
1559
|
+
)
|
1560
|
+
|
1561
|
+
else: # No active pool
|
1562
|
+
await asyncio.sleep(
|
1563
|
+
0.01
|
1564
|
+
) # Small sleep if no active pool and pending queue was empty
|
1565
|
+
|
1566
|
+
except asyncio.CancelledError:
|
1567
|
+
logger_instance.info("Batch processing loop cancelled.")
|
1568
|
+
all_requests_to_cancel = list(active_pool_requests)
|
1569
|
+
active_pool_requests.clear()
|
1570
|
+
active_pool_signature = None
|
1571
|
+
for sig_list in pending_requests_by_signature.values():
|
1572
|
+
all_requests_to_cancel.extend(sig_list)
|
1573
|
+
pending_requests_by_signature.clear()
|
1574
|
+
|
1575
|
+
for req_state in all_requests_to_cancel:
|
1576
|
+
if not req_state.completed_and_signaled:
|
1577
|
+
req_state.result_container[0] = JSONResponse(
|
1578
|
+
status_code=503,
|
1579
|
+
content={"error": "Server shutting down, request cancelled."},
|
1580
|
+
)
|
1581
|
+
req_state.completion_event.set()
|
1582
|
+
req_state.completed_and_signaled = True
|
1583
|
+
break
|
1584
|
+
except Exception as e:
|
1585
|
+
logger_instance.error(
|
1586
|
+
f"Critical error in batch processing loop: {e}", exc_info=True
|
1587
|
+
)
|
1588
|
+
all_requests_to_fail = list(active_pool_requests)
|
1589
|
+
active_pool_requests.clear()
|
1590
|
+
active_pool_signature = None
|
1591
|
+
for sig_list in pending_requests_by_signature.values():
|
1592
|
+
all_requests_to_fail.extend(sig_list)
|
1593
|
+
pending_requests_by_signature.clear()
|
1594
|
+
|
1595
|
+
for req_state in all_requests_to_fail:
|
1596
|
+
if not req_state.completed_and_signaled:
|
1597
|
+
req_state.result_container[0] = JSONResponse(
|
1598
|
+
status_code=500,
|
1599
|
+
content={"error": f"Critical batch processor error: {str(e)}"},
|
1600
|
+
)
|
1601
|
+
req_state.completion_event.set()
|
1602
|
+
req_state.completed_and_signaled = True
|
1603
|
+
await asyncio.sleep(1) # Pause before retrying loop
|
1604
|
+
|
1605
|
+
|
1606
|
+
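# --- Illustrative sketch (editorial addition, not part of the original module) ---
# The loop above interprets vLLM's per-chunk finish_reason instead of trusting it
# directly: a "length" stop that only reflects the chunk budget is ignored so the
# request can continue in the next chunk, while a short chunk with no reason set is
# treated as a natural stop. A minimal, standalone restatement of that rule
# (function and parameter names here are hypothetical):
def _interpret_chunk_finish_reason(
    vllm_finish_reason,
    generated_token_count: int,
    effective_max_tokens: int,
    new_token_count: int,
    chunk_size: int,
    current_finish_reason=None,
):
    """Return the finish_reason a pooled request should carry after one chunk."""
    finish_reason = current_finish_reason
    if vllm_finish_reason == "length":
        # Only a real "length" stop if the request's own budget is exhausted.
        if generated_token_count >= effective_max_tokens:
            finish_reason = "length"
    elif vllm_finish_reason is not None:
        # stop / eos_token etc. are genuine completions.
        finish_reason = vllm_finish_reason
    if new_token_count < chunk_size and finish_reason is None:
        # An incomplete chunk means generation stopped naturally.
        finish_reason = "stop"
    return finish_reason
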
def main(script_args: ScriptArguments):
    global request_queue, batch_processor_task  # Allow lifespan to assign to these

    # Spawn dp workers, and setup pipes for communication
    master_port = get_open_port()
    connections: list[AnyType] = (
        []
    )  # Use Any type to avoid PipeConnection vs Connection mismatch
    processes = []
    for data_parallel_rank in range(script_args.data_parallel_size):
        # Use duplex=True to allow bidirectional communication
        # This is needed for "call" type commands that expect responses
        parent_connection, child_connection = Pipe(duplex=True)
        process = Process(
            target=llm_worker,
            args=(script_args, data_parallel_rank, master_port, child_connection),
        )
        process.start()
        connections.append(parent_connection)
        processes.append(process)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        nonlocal processes  # Capture from outer scope
        global request_queue, batch_processor_task  # Defined at module level

        logger.info(
            f"Lifespan: Waiting for {script_args.data_parallel_size} LLM worker(s) to be ready..."
        )
        ready_connections = set()

        # Timeout for waiting for workers to get ready (e.g., 5 minutes)
        timeout_seconds = 300
        start_wait_time = time.time()

        while len(ready_connections) < script_args.data_parallel_size:
            if time.time() - start_wait_time > timeout_seconds:
                logger.error(
                    f"Lifespan: Timeout waiting for all LLM workers. Expected {script_args.data_parallel_size}, got {len(ready_connections)} ready."
                )
                raise RuntimeError("LLM workers failed to initialize in time")

            for i, connection in enumerate(connections):
                if connection not in ready_connections:
                    # Use poll() with a short timeout to avoid blocking indefinitely if a worker is stuck
                    if connection.poll(
                        0.1
                    ):  # Check if data is available, with a 0.1s timeout
                        try:
                            msg = connection.recv()
                            logger.info(
                                f"Lifespan: Received message from worker {i}: {msg}"
                            )
                            if isinstance(msg, dict) and msg.get("status") == "ready":
                                logger.info(f"Lifespan: LLM worker {i} reported ready.")
                                ready_connections.add(connection)
                            else:
                                logger.warning(
                                    f"Lifespan: Received unexpected message from worker {i}: {msg}"
                                )
                        except Exception as e:
                            logger.error(
                                f"Lifespan: Error receiving message from worker {i}: {e}"
                            )

            if len(ready_connections) < script_args.data_parallel_size:
                time.sleep(
                    0.5
                )  # Brief sleep to avoid busy-waiting if not all workers are ready yet

        if len(ready_connections) == script_args.data_parallel_size:
            logger.info(
                f"Lifespan: All {script_args.data_parallel_size} LLM worker(s) are ready. Proceeding to yield."
            )
            # Initialize request queue and start batch processor task
            request_queue = asyncio.Queue()
            logger.info(
                "Lifespan: Initialized request queue for batched chat completions."
            )
            batch_processor_task = asyncio.create_task(
                batch_processing_loop(script_args, connections, request_queue, logger)
            )
            logger.info("Lifespan: Started batch processing task for chat completions.")
        else:
            logger.error(
                f"Lifespan: Not all LLM workers became ready. Expected {script_args.data_parallel_size}, got {len(ready_connections)}. Uvicorn might not function correctly. Batch processor NOT started."
            )

        yield
        logger.info("Lifespan: Uvicorn server is shutting down. Cleaning up resources.")

        if batch_processor_task is not None and not batch_processor_task.done():
            logger.info("Lifespan: Cancelling batch processor task...")
            batch_processor_task.cancel()
            try:
                await batch_processor_task
                logger.info("Lifespan: Batch processor task finished.")
            except asyncio.CancelledError:
                logger.info("Lifespan: Batch processor task was cancelled as expected.")
            except Exception as e:
                logger.error(
                    f"Lifespan: Exception during batch processor task shutdown: {e}",
                    exc_info=True,
                )

        # Wait for processes to terminate
        for process in processes:
            process.join(timeout=10)  # Wait for 10 seconds for the process to terminate
            if process.is_alive():
                logger.warning(
                    f"Process {process} is still alive after 10 seconds, attempting to terminate..."
                )
                process.terminate()
                process.join()  # ensure process termination after calling terminate()

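    # --- Illustrative sketch (editorial addition, not part of the original module) ---
    # The parent process talks to each llm_worker over its duplex Pipe using small dict
    # messages: {"type": "fire_and_forget" | "call", "method": ..., "kwargs": ...}.
    # "call" messages expect a reply via connection.recv(); "fire_and_forget" does not.
    # A hedged helper restating that convention (never invoked by the server itself):
    def _send_worker_command(connection, method, kwargs=None, expect_reply=False):
        message = {
            "type": "call" if expect_reply else "fire_and_forget",
            "method": method,
        }
        if kwargs is not None:
            message["kwargs"] = kwargs
        connection.send(message)
        if expect_reply:
            # Blocks until the worker answers, mirroring the reset_prefix_cache flow.
            return connection.recv()
        return None
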
    app = FastAPI(lifespan=lifespan)

    # Define the endpoints for the model server
    @app.get("/health/")
    async def health():
        """
        Health check endpoint to verify that the server is running.
        """
        return {"status": "ok"}

    @app.get("/get_world_size/")
    async def get_world_size():
        """
        Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.

        Returns:
            `dict`:
                A dictionary containing the world size.

        Example response:
        ```json
        {"world_size": 8}
        ```
        """
        return {
            "world_size": script_args.tensor_parallel_size
            * script_args.data_parallel_size
        }

    # -------- OpenAI-chat ---------- #
    @app.post("/v1/chat/completions", response_model=OAChatCompletionResponse)
    async def openai_chat(req: OAChatCompletionRequest):
        global request_queue

        if request_queue is None:
            logger.error("/v1/chat/completions: Request queue not initialized.")
            return JSONResponse(
                status_code=503,
                content={
                    "error": "Server not ready, batch processing queue not initialized."
                },
            )

        request_id = f"chatcmpl-{uuid4().hex}"
        logger.debug(f"Received chat request {request_id}, model: {req.model}")

        # Create signature for pooling
        # The OAChatCompletionRequest fields are: model, messages, temperature, top_p, max_tokens, stream, extra_body
        # We need to pass the relevant ones to create_pool_signature
        raw_params_for_sig = {
            "temperature": req.temperature,
            "top_p": req.top_p,
            "presence_penalty": req.presence_penalty,
            "frequency_penalty": req.frequency_penalty,
            # "n" is not in OAChatCompletionRequest, defaults to 1 for chat in OpenAI spec
            "n": 1,
            "response_format": req.response_format,
        }
        # Add other SamplingParams-mappable fields if they were part of OAChatCompletionRequest
        # For example, if we added 'stop', 'presence_penalty' etc. to OAChatCompletionRequest.
        # For now, the above are the main ones.

        default_max = OAChatCompletionRequest.model_fields["max_tokens"].default
        effective_max_tokens = req.max_tokens or default_max

        pool_sig = create_pool_signature(
            model_name=req.model,
            request_type="chat",
            raw_request_params=raw_params_for_sig,
            extra_body=req.extra_body,
        )

        completion_event = asyncio.Event()
        result_container = [None]

        pooled_state = PooledRequestState(
            original_request=req,
            completion_event=completion_event,
            result_container=result_container,
            request_id=request_id,
            request_type="chat",
            pool_signature=pool_sig,
            effective_max_tokens=effective_max_tokens,
            original_chat_messages=req.messages,  # Store original messages
        )

        logger.info(
            f"[LIFECYCLE] Created chat request {request_id}: max_tokens={effective_max_tokens}, "
            f"messages={len(req.messages)}, model={req.model}"
        )

        try:
            await request_queue.put(pooled_state)
            logger.debug(f"[LIFECYCLE] Request {request_id} successfully queued")
        except Exception as e:
            logger.error(f"Enqueueing error for {request_id}: {e}", exc_info=True)
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error while queueing request."},
            )

        try:
            await asyncio.wait_for(
                completion_event.wait(),
                timeout=script_args.batch_request_timeout_seconds,
            )
        except asyncio.TimeoutError:
            logger.error(f"Timeout for chat request {request_id} (model {req.model}).")
            pooled_state.timed_out = True
            pooled_state.completed_and_signaled = True
            pooled_state.completion_event.set()
            return JSONResponse(
                status_code=504, content={"error": "Request timed out."}
            )
        except Exception as e:
            logger.error(
                f"Error waiting for completion event for {request_id}: {e}",
                exc_info=True,
            )
            return JSONResponse(
                status_code=500,
                content={
                    "error": "Internal server error while waiting for completion."
                },
            )

        response_data = result_container[0]
        if (
            response_data is None
        ):  # Should ideally be set to an error by processor if timeout internally
            logger.error(
                f"No result for {request_id} (model {req.model}) after event set. Internal error."
            )
            return JSONResponse(
                status_code=500,
                content={"error": "Internal error: No result from processor."},
            )

        if isinstance(response_data, JSONResponse):
            return response_data

        if isinstance(response_data, OAChatCompletionResponse):
            # Must return JSONResponse for FastAPI
            return JSONResponse(response_data.model_dump())
        else:
            logger.error(
                f"Unexpected result type {type(response_data)} for {request_id}."
            )
            return JSONResponse(
                status_code=500,
                content={"error": "Internal error: Unexpected result format."},
            )

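    # --- Illustrative sketch (editorial addition, not part of the original module) ---
    # Requests that share the same sampling parameters (temperature, top_p, penalties,
    # n, response_format) and extra_body produce the same pool signature and are
    # batched together by the processing loop. A minimal client call, assuming the
    # `requests` package is installed and the server listens on localhost:8000
    # (both assumptions):
    def _example_chat_request(model_name: str, base_url: str = "http://localhost:8000"):
        import requests  # assumed to be available in the client environment

        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": "Say hello in one word."}],
            "temperature": 0.7,
            "top_p": 0.95,
            "max_tokens": 32,
        }
        resp = requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=120)
        resp.raise_for_status()
        data = resp.json()
        # choices[0].message.content holds the accumulated text; finish_reason is
        # "stop" or "length" as decided by the batch processor above.
        return data["choices"][0]["message"]["content"]
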
@app.get("/v1/models")
|
1875
|
+
async def list_models():
|
1876
|
+
return {
|
1877
|
+
"data": [{"id": script_args.model, "object": "model", "owned_by": "vllm"}]
|
1878
|
+
}
|
1879
|
+
|
1880
|
+
@app.post("/v1/completions", response_model=OACompletionResponse)
|
1881
|
+
async def openai_completions(req: OACompletionRequest):
|
1882
|
+
global request_queue
|
1883
|
+
|
1884
|
+
if request_queue is None:
|
1885
|
+
logger.error("/v1/completions: Request queue not initialized.")
|
1886
|
+
return JSONResponse(
|
1887
|
+
status_code=503,
|
1888
|
+
content={
|
1889
|
+
"error": "Server not ready, batch processing queue not initialized."
|
1890
|
+
},
|
1891
|
+
)
|
1892
|
+
|
1893
|
+
request_id = f"cmpl-{uuid4().hex}"
|
1894
|
+
logger.debug(f"Received completion request {request_id}, model: {req.model}")
|
1895
|
+
|
1896
|
+
# OACompletionRequest fields: model, prompt, temperature, top_p, max_tokens, n, extra_body
|
1897
|
+
raw_params_for_sig = {
|
1898
|
+
"temperature": req.temperature,
|
1899
|
+
"top_p": req.top_p,
|
1900
|
+
"presence_penalty": req.presence_penalty,
|
1901
|
+
"frequency_penalty": req.frequency_penalty,
|
1902
|
+
"n": req.n, # Pass 'n' from the request
|
1903
|
+
}
|
1904
|
+
# Add other SamplingParams-mappable fields from OACompletionRequest if they exist
|
1905
|
+
# e.g., req.stop, req.presence_penalty etc. if we add them to OACompletionRequest model
|
1906
|
+
# For now, the above are the main ones.
|
1907
|
+
# We need to get ALL fields of OACompletionRequest that are also valid for SamplingParams
|
1908
|
+
# This is safer:
|
1909
|
+
valid_sp_keys = _get_sampling_param_names()
|
1910
|
+
for field_name, field_value in req.model_dump().items():
|
1911
|
+
if field_name in valid_sp_keys and field_name not in raw_params_for_sig:
|
1912
|
+
raw_params_for_sig[field_name] = field_value
|
1913
|
+
|
1914
|
+
default_max = OACompletionRequest.model_fields["max_tokens"].default
|
1915
|
+
effective_max_tokens = req.max_tokens or default_max
|
1916
|
+
|
1917
|
+
pool_sig = create_pool_signature(
|
1918
|
+
model_name=req.model,
|
1919
|
+
request_type="completion",
|
1920
|
+
raw_request_params=raw_params_for_sig,
|
1921
|
+
extra_body=req.extra_body,
|
1922
|
+
)
|
1923
|
+
|
1924
|
+
completion_event = asyncio.Event()
|
1925
|
+
result_container = [None]
|
1926
|
+
|
1927
|
+
# Check for list prompts for completion, which is problematic for current chunking.
|
1928
|
+
# vLLM's generate can take list of prompts, but our chunking logic (appending to prompt) assumes single string.
|
1929
|
+
if isinstance(req.prompt, list):
|
1930
|
+
if len(req.prompt) > 1:
|
1931
|
+
logger.warning(
|
1932
|
+
f"Request {request_id} for completion has a list of prompts. Only the first prompt will be used for chunked generation."
|
1933
|
+
)
|
1934
|
+
current_prompt = req.prompt[0] if req.prompt else ""
|
1935
|
+
elif not req.prompt: # empty list (simplified condition)
|
1936
|
+
current_prompt = ""
|
1937
|
+
else: # list with one element
|
1938
|
+
current_prompt = req.prompt[0]
|
1939
|
+
else: # string
|
1940
|
+
current_prompt = req.prompt
|
1941
|
+
|
1942
|
+
pooled_state = PooledRequestState(
|
1943
|
+
original_request=req,
|
1944
|
+
completion_event=completion_event,
|
1945
|
+
result_container=result_container,
|
1946
|
+
request_id=request_id,
|
1947
|
+
request_type="completion",
|
1948
|
+
pool_signature=pool_sig,
|
1949
|
+
effective_max_tokens=effective_max_tokens,
|
1950
|
+
original_prompt=current_prompt, # Store single prompt for chunking
|
1951
|
+
)
|
1952
|
+
|
1953
|
+
try:
|
1954
|
+
await request_queue.put(pooled_state)
|
1955
|
+
except Exception as e:
|
1956
|
+
logger.error(
|
1957
|
+
f"Enqueueing error for completion {request_id}: {e}", exc_info=True
|
1958
|
+
)
|
1959
|
+
return JSONResponse(
|
1960
|
+
status_code=500,
|
1961
|
+
content={"error": "Internal server error while queueing request."},
|
1962
|
+
)
|
1963
|
+
|
1964
|
+
try:
|
1965
|
+
await asyncio.wait_for(
|
1966
|
+
completion_event.wait(),
|
1967
|
+
timeout=script_args.batch_request_timeout_seconds,
|
1968
|
+
)
|
1969
|
+
except asyncio.TimeoutError:
|
1970
|
+
logger.error(
|
1971
|
+
f"Timeout for completion request {request_id} (model {req.model})."
|
1972
|
+
)
|
1973
|
+
pooled_state.timed_out = True
|
1974
|
+
pooled_state.completed_and_signaled = True
|
1975
|
+
pooled_state.completion_event.set()
|
1976
|
+
return JSONResponse(
|
1977
|
+
status_code=504, content={"error": "Request timed out."}
|
1978
|
+
)
|
1979
|
+
except Exception as e:
|
1980
|
+
logger.error(
|
1981
|
+
f"Error waiting for completion event for {request_id}: {e}",
|
1982
|
+
exc_info=True,
|
1983
|
+
)
|
1984
|
+
return JSONResponse(
|
1985
|
+
status_code=500,
|
1986
|
+
content={
|
1987
|
+
"error": "Internal server error while waiting for completion."
|
1988
|
+
},
|
1989
|
+
)
|
1990
|
+
|
1991
|
+
response_data = result_container[0]
|
1992
|
+
if response_data is None:
|
1993
|
+
logger.error(
|
1994
|
+
f"No result for completion {request_id} (model {req.model}) after event set. Internal error."
|
1995
|
+
)
|
1996
|
+
return JSONResponse(
|
1997
|
+
status_code=500,
|
1998
|
+
content={"error": "Internal error: No result from processor."},
|
1999
|
+
)
|
2000
|
+
|
2001
|
+
if isinstance(response_data, JSONResponse):
|
2002
|
+
return response_data
|
2003
|
+
|
2004
|
+
if isinstance(response_data, OACompletionResponse):
|
2005
|
+
return JSONResponse(
|
2006
|
+
response_data.model_dump()
|
2007
|
+
) # Must return JSONResponse for FastAPI
|
2008
|
+
else:
|
2009
|
+
logger.error(
|
2010
|
+
f"Unexpected result type {type(response_data)} for completion {request_id}."
|
2011
|
+
)
|
2012
|
+
return JSONResponse(
|
2013
|
+
status_code=500,
|
2014
|
+
content={"error": "Internal error: Unexpected result format."},
|
2015
|
+
)
|
2016
|
+
|
2017
|
+
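    # --- Illustrative sketch (editorial addition, not part of the original module) ---
    # The /v1/completions route accepts a single prompt string (lists are reduced to
    # their first element for chunked generation, as noted above). A minimal client
    # call, assuming `requests` is installed and the default host/port (assumptions):
    def _example_completion_request(model_name: str, base_url: str = "http://localhost:8000"):
        import requests  # assumed to be available in the client environment

        payload = {
            "model": model_name,
            "prompt": "The capital of France is",
            "max_tokens": 16,
            "temperature": 0.0,
            "n": 1,
        }
        resp = requests.post(f"{base_url}/v1/completions", json=payload, timeout=120)
        resp.raise_for_status()
        return resp.json()["choices"][0]["text"]
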
    class InitCommunicatorRequest(BaseModel):
        host: str
        port: int
        world_size: int

    @app.post("/init_communicator/")
    async def init_communicator(request: InitCommunicatorRequest):
        """
        Initializes the communicator for synchronizing model weights between a client and multiple server
        workers.

        Args:
            request (`InitCommunicatorRequest`):
                - `host` (`str`): Hostname or IP address of the master node.
                - `port` (`int`): Port number to be used for communication.
                - `world_size` (`int`): Total number of participating processes in the group.
        """
        logger.info(
            f"[INIT_COMMUNICATOR] Received request: host={request.host}, port={request.port}, world_size={request.world_size}"
        )

        # Calculate actual world size based on vLLM configuration
        vllm_world_size = (
            script_args.tensor_parallel_size * script_args.data_parallel_size
        )
        expected_world_size = vllm_world_size + 1  # +1 for the client

        logger.info(
            f"[INIT_COMMUNICATOR] vLLM world size: {vllm_world_size} (TP={script_args.tensor_parallel_size} x DP={script_args.data_parallel_size})"
        )
        logger.info(
            f"[INIT_COMMUNICATOR] Expected total world size: {expected_world_size}"
        )

        if request.world_size != expected_world_size:
            logger.warning(
                f"[INIT_COMMUNICATOR] World size mismatch! Request: {request.world_size}, Expected: {expected_world_size}"
            )

        # The function init_communicator is called this way: init_communicator(host, port, world_size)
        # So with collective_rpc we need to call it this way:
        # llm.collective_rpc(method="init_communicator", args=(host, port, world_size))
        kwargs = {
            "method": "init_communicator",
            "args": (request.host, request.port, expected_world_size),
        }

        # Send to all workers synchronously to ensure they're ready
        successful_workers = []
        failed_workers = []

        for i, connection in enumerate(connections):
            logger.debug(f"[INIT_COMMUNICATOR] Sending to worker {i}")
            try:
                connection.send(
                    {
                        "type": "fire_and_forget",
                        "method": "collective_rpc",
                        "kwargs": kwargs,
                    }
                )
                successful_workers.append(i)
            except Exception as e:
                logger.error(f"[INIT_COMMUNICATOR] Failed to notify worker {i}: {e}")
                failed_workers.append((i, str(e)))

        if failed_workers:
            error_msg = f"Failed to notify workers: {failed_workers}"
            logger.error(f"[INIT_COMMUNICATOR] {error_msg}")
            return JSONResponse(status_code=500, content={"error": error_msg})

        logger.info(
            f"[INIT_COMMUNICATOR] Successfully notified {len(successful_workers)} workers"
        )
        return {
            "message": "Request received, initializing communicator",
            "workers_notified": len(successful_workers),
        }

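    # --- Illustrative sketch (editorial addition, not part of the original module) ---
    # A trainer initializes the weight-sync group by posting host/port/world_size,
    # where world_size is tensor_parallel_size * data_parallel_size + 1 (the +1 being
    # the trainer itself). Sketch only; `requests`, the host, and the port values are
    # assumptions:
    def _example_init_communicator(
        tensor_parallel_size: int,
        data_parallel_size: int,
        nccl_host: str = "127.0.0.1",
        nccl_port: int = 51216,
        base_url: str = "http://localhost:8000",
    ):
        import requests  # assumed to be available in the trainer environment

        world_size = tensor_parallel_size * data_parallel_size + 1
        resp = requests.post(
            f"{base_url}/init_communicator/",
            json={"host": nccl_host, "port": nccl_port, "world_size": world_size},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json()
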
    class UpdateWeightsRequest(BaseModel):
        name: str
        dtype: str
        shape: list[int]

    class BatchUpdateWeightsRequest(BaseModel):
        updates: list[UpdateWeightsRequest]

    @app.post("/update_named_param/")
    async def update_named_param(request: UpdateWeightsRequest):
        """
        Updates the model weights with the provided tensor.

        Once this endpoint is called, the client process should broadcast the updated weights to all server workers.

        Args:
            request (`UpdateWeightsRequest`):
                - `name` (`str`): Name of the weight tensor being updated.
                - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
                - `shape` (list of `int`): Shape of the weight

        """
        # Acquire semaphore to limit concurrent weight updates
        async with weight_update_semaphore:
            logger.info(
                f"[UPDATE_PARAM] Received weight update for: {request.name}, dtype={request.dtype}, shape={request.shape}"
            )

            # Wait for active generations to complete before updating weights
            # This prevents conflicts between generation and weight loading
            wait_start = time.time()
            if active_generation_count > 0:
                logger.info(
                    f"[UPDATE_PARAM] Waiting for {active_generation_count} active generations to complete before weight update"
                )
                while active_generation_count > 0:
                    if time.time() - wait_start > 30.0:  # 30 second timeout
                        logger.warning(
                            f"[UPDATE_PARAM] Timeout waiting for {active_generation_count} active generations to complete"
                        )
                        break
                    await asyncio.sleep(0.1)
                if active_generation_count == 0:
                    logger.debug(
                        f"[UPDATE_PARAM] All generations complete, proceeding with weight update"
                    )

            # CRITICAL: Notify workers IMMEDIATELY so they're ready for NCCL broadcast
            # This must happen before returning the HTTP response to maintain synchronization with trainer
            # dtype = torch.__getattribute__(request.dtype.split(".")[-1])
            kwargs = {
                "method": "update_named_param",
                "args": (request.name, request.dtype, tuple(request.shape)),
            }

            # Send to all workers synchronously to ensure they're ready
            # Using fire_and_forget since we don't need the result
            for i, connection in enumerate(connections):
                logger.debug(f"[UPDATE_PARAM] Notifying worker {i} about weight update")
                try:
                    connection.send(
                        {
                            "type": "fire_and_forget",
                            "method": "collective_rpc",
                            "kwargs": kwargs,
                        }
                    )
                except Exception as e:
                    logger.error(f"[UPDATE_PARAM] Failed to notify worker {i}: {e}")
                    return JSONResponse(
                        status_code=500,
                        content={"error": f"Failed to notify worker {i}: {str(e)}"},
                    )

            logger.debug(
                f"[UPDATE_PARAM] All workers notified, trainer should now broadcast via NCCL"
            )
            return {"message": "Weight update processed"}

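    # --- Illustrative sketch (editorial addition, not part of the original module) ---
    # The metadata POST only tells the workers which tensor to expect; the tensor data
    # itself must then be broadcast by the trainer over the NCCL group created via
    # /init_communicator/ (in this package that is handled by the vLLM client, not
    # shown here). Sketch of the metadata call; `requests` and the URL are assumptions:
    def _example_announce_weight_update(
        name: str,
        dtype: str,
        shape: list,
        base_url: str = "http://localhost:8000",
    ):
        import requests  # assumed to be available in the trainer environment

        resp = requests.post(
            f"{base_url}/update_named_param/",
            json={"name": name, "dtype": dtype, "shape": list(shape)},
            timeout=60,
        )
        resp.raise_for_status()
        # After this returns, the trainer is expected to broadcast the tensor
        # (dtype given as a string such as "torch.float32") to all vLLM workers.
        return resp.json()
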
@app.post("/batch_update_named_params/")
|
2176
|
+
async def batch_update_named_params(request: BatchUpdateWeightsRequest):
|
2177
|
+
"""
|
2178
|
+
Updates multiple model weights in a batch. Processes updates sequentially
|
2179
|
+
to ensure proper synchronization with NCCL broadcasts from the client.
|
2180
|
+
|
2181
|
+
Args:
|
2182
|
+
request (`BatchUpdateWeightsRequest`):
|
2183
|
+
- `updates` (list of `UpdateWeightsRequest`): List of weight updates to process
|
2184
|
+
"""
|
2185
|
+
logger.info(
|
2186
|
+
f"[BATCH_UPDATE] Received batch of {len(request.updates)} weight updates"
|
2187
|
+
)
|
2188
|
+
|
2189
|
+
# Process updates sequentially to maintain NCCL synchronization
|
2190
|
+
# The client will broadcast each parameter after we notify workers
|
2191
|
+
successful = []
|
2192
|
+
errors = []
|
2193
|
+
|
2194
|
+
for update in request.updates:
|
2195
|
+
try:
|
2196
|
+
# Acquire semaphore to limit concurrent updates across different batches
|
2197
|
+
async with weight_update_semaphore:
|
2198
|
+
logger.debug(
|
2199
|
+
f"[BATCH_UPDATE] Processing weight update for: {update.name}"
|
2200
|
+
)
|
2201
|
+
|
2202
|
+
# Wait for active generations if needed
|
2203
|
+
wait_start = time.time()
|
2204
|
+
while active_generation_count > 0:
|
2205
|
+
if time.time() - wait_start > 30.0:
|
2206
|
+
logger.warning(
|
2207
|
+
f"[BATCH_UPDATE] Timeout waiting for generations"
|
2208
|
+
)
|
2209
|
+
break
|
2210
|
+
await asyncio.sleep(0.1)
|
2211
|
+
|
2212
|
+
# Notify workers synchronously
|
2213
|
+
dtype = getattr(torch, update.dtype.split(".")[-1])
|
2214
|
+
kwargs = {
|
2215
|
+
"method": "update_named_param",
|
2216
|
+
"args": (update.name, dtype, tuple(update.shape)),
|
2217
|
+
}
|
2218
|
+
|
2219
|
+
for i, connection in enumerate(connections):
|
2220
|
+
try:
|
2221
|
+
connection.send(
|
2222
|
+
{
|
2223
|
+
"type": "fire_and_forget",
|
2224
|
+
"method": "collective_rpc",
|
2225
|
+
"kwargs": kwargs,
|
2226
|
+
}
|
2227
|
+
)
|
2228
|
+
except Exception as e:
|
2229
|
+
logger.error(
|
2230
|
+
f"[BATCH_UPDATE] Failed to notify worker {i} for {update.name}: {e}"
|
2231
|
+
)
|
2232
|
+
raise Exception(f"Failed to notify worker {i}")
|
2233
|
+
|
2234
|
+
successful.append(update.name)
|
2235
|
+
logger.debug(f"[BATCH_UPDATE] Workers notified for {update.name}")
|
2236
|
+
|
2237
|
+
except Exception as e:
|
2238
|
+
errors.append({"param": update.name, "error": str(e)})
|
2239
|
+
logger.error(f"[BATCH_UPDATE] Error processing {update.name}: {e}")
|
2240
|
+
|
2241
|
+
if errors:
|
2242
|
+
return JSONResponse(
|
2243
|
+
status_code=207, # Multi-Status
|
2244
|
+
content={
|
2245
|
+
"message": f"Batch update completed with {len(errors)} errors",
|
2246
|
+
"successful": successful,
|
2247
|
+
"errors": errors,
|
2248
|
+
},
|
2249
|
+
)
|
2250
|
+
|
2251
|
+
logger.info(
|
2252
|
+
f"[BATCH_UPDATE] Successfully processed {len(successful)} weight updates"
|
2253
|
+
)
|
2254
|
+
return {
|
2255
|
+
"message": f"Successfully updated {len(successful)} parameters",
|
2256
|
+
"successful": successful,
|
2257
|
+
}
|
2258
|
+
|
2259
|
+
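    # --- Illustrative sketch (editorial addition, not part of the original module) ---
    # /batch_update_named_params/ takes a list of the same metadata records; the
    # trainer then broadcasts each tensor in the same order once the workers have been
    # notified. Example request body only (parameter names and shapes are hypothetical):
    _EXAMPLE_BATCH_UPDATE_BODY = {
        "updates": [
            {"name": "model.embed_tokens.weight", "dtype": "torch.bfloat16", "shape": [151936, 2048]},
            {"name": "lm_head.weight", "dtype": "torch.bfloat16", "shape": [151936, 2048]},
        ]
    }
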
@app.post("/reset_prefix_cache/")
|
2260
|
+
async def reset_prefix_cache():
|
2261
|
+
"""
|
2262
|
+
Resets the prefix cache for the model.
|
2263
|
+
"""
|
2264
|
+
# Send requests and collect results synchronously
|
2265
|
+
all_outputs = []
|
2266
|
+
for connection in connections:
|
2267
|
+
try:
|
2268
|
+
connection.send({"type": "call", "method": "reset_prefix_cache"})
|
2269
|
+
output = connection.recv()
|
2270
|
+
all_outputs.append(output)
|
2271
|
+
except Exception as e:
|
2272
|
+
logger.error(f"Failed to reset prefix cache on worker: {e}")
|
2273
|
+
all_outputs.append(False)
|
2274
|
+
|
2275
|
+
success = all(output for output in all_outputs)
|
2276
|
+
return {
|
2277
|
+
"message": "Request received, resetting prefix cache status: "
|
2278
|
+
+ str(success)
|
2279
|
+
}
|
2280
|
+
|
2281
|
+
@app.post("/close_communicator/")
|
2282
|
+
async def close_communicator():
|
2283
|
+
"""
|
2284
|
+
Closes the weight update group and cleans up associated resources.
|
2285
|
+
"""
|
2286
|
+
kwargs = {"method": "close_communicator"}
|
2287
|
+
|
2288
|
+
# Send to all workers
|
2289
|
+
for connection in connections:
|
2290
|
+
try:
|
2291
|
+
connection.send(
|
2292
|
+
{
|
2293
|
+
"type": "fire_and_forget",
|
2294
|
+
"method": "collective_rpc",
|
2295
|
+
"kwargs": kwargs,
|
2296
|
+
}
|
2297
|
+
)
|
2298
|
+
except Exception as e:
|
2299
|
+
logger.warning(f"Failed to send close_communicator to worker: {e}")
|
2300
|
+
# Don't fail the request if we can't notify a worker during shutdown
|
2301
|
+
|
2302
|
+
return {"message": "Request received, closing communicator"}
|
2303
|
+
|
2304
|
+
# Start the server
|
2305
|
+
# Always use 1 Uvicorn worker. vLLM handles its own worker processes and scheduling.
|
2306
|
+
num_uvicorn_workers = 1
|
2307
|
+
|
2308
|
+
logger.info(
|
2309
|
+
f"Starting Uvicorn with {num_uvicorn_workers} worker(s). Data parallel size: {script_args.data_parallel_size}"
|
2310
|
+
)
|
2311
|
+
uvicorn.run(
|
2312
|
+
app,
|
2313
|
+
host=script_args.host,
|
2314
|
+
port=script_args.port,
|
2315
|
+
log_level=script_args.log_level,
|
2316
|
+
workers=num_uvicorn_workers,
|
2317
|
+
)
|
2318
|
+
|
2319
|
+
|
2320
|
+
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
2321
|
+
if subparsers is not None:
|
2322
|
+
parser = subparsers.add_parser(
|
2323
|
+
"vllm-serve",
|
2324
|
+
help="Run the vLLM serve script",
|
2325
|
+
dataclass_types=ScriptArguments,
|
2326
|
+
)
|
2327
|
+
else:
|
2328
|
+
parser = TrlParser(ScriptArguments)
|
2329
|
+
return parser
|
2330
|
+
|
2331
|
+
|
2332
|
+
if __name__ == "__main__":
|
2333
|
+
parser = make_parser()
|
2334
|
+
(script_args,) = parser.parse_args_and_config()
|
2335
|
+
main(script_args)
|
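

# --- Illustrative sketch (editorial addition, not part of the original module) ---
# A controller that launches this script typically polls /health/ until the lifespan
# hook has finished waiting for the LLM workers, then reads /get_world_size/ before
# initializing weight sync. Sketch only; `requests` and the URL are assumptions:
def _example_wait_until_ready(base_url: str = "http://localhost:8000", timeout_s: float = 300.0):
    import time

    import requests  # assumed to be available in the controller environment

    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{base_url}/health/", timeout=5).status_code == 200:
                return requests.get(f"{base_url}/get_world_size/", timeout=5).json()["world_size"]
        except requests.RequestException:
            pass
        time.sleep(1.0)
    raise TimeoutError("vLLM server did not become healthy in time")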