arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2335 @@
1
+ """
2
+ OpenAI-compatible vLLM server with weight synchronization.
3
+
4
+ Usage:
5
+
6
+ ```bash
7
+ uv run python vllm_server.py --model <model_name> --port <port>
8
+ ```
9
+
10
+ Supports:
11
+ - /v1/chat/completions
12
+ - /v1/completions
13
+ """
14
+
15
+ import argparse
16
+ import asyncio
17
+ import inspect
18
+ import json
19
+ import logging
20
+ import os
21
+ import threading
22
+ import time
23
+ import traceback
24
+ from collections import defaultdict
25
+ from collections.abc import Sequence
26
+ from contextlib import asynccontextmanager
27
+ from dataclasses import dataclass, field
28
+ from datetime import datetime, timezone
29
+ from multiprocessing import Pipe, Process
30
+ from multiprocessing.connection import Connection as MPConnection
31
+ from typing import Any
32
+ from typing import Any as AnyType
33
+ from typing import Literal, Optional
34
+ from uuid import uuid4
35
+
36
+ import torch
37
+ import uvicorn
38
+ from fastapi import FastAPI
39
+ from fastapi.responses import JSONResponse, StreamingResponse
40
+ from pydantic import BaseModel, ConfigDict
41
+ from trl import TrlParser
42
+ from vllm import LLM, SamplingParams
43
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
44
+ from vllm.distributed.parallel_state import get_world_group
45
+ from vllm.distributed.utils import StatelessProcessGroup
46
+ from vllm.sampling_params import GuidedDecodingParams
47
+ from vllm.utils import get_open_port
48
+
49
+ logging.basicConfig(
50
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
51
+ )
52
+ logger = logging.getLogger(__name__) # Ensure logger is defined
53
+
54
+ # We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following
55
+ # error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use
56
+ # the 'spawn' start method
57
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
58
+
59
+ # At the global level, after imports and logger setup:
60
+ pipe_lock = threading.Lock() # Global lock for pipe operations
61
+ request_queue: Optional[asyncio.Queue] = None
62
+ batch_processor_task: Optional[asyncio.Task] = None
63
+
64
+ # Generation tracking
65
+ active_generation_count = 0
66
+ generation_count_lock = asyncio.Lock()
67
+
68
+ # Weight update throttling
69
+ MAX_CONCURRENT_WEIGHT_UPDATES = 5
70
+ weight_update_semaphore = asyncio.Semaphore(MAX_CONCURRENT_WEIGHT_UPDATES)
71
+
72
+ # Worker rotation for load balancing
73
+ worker_rotation_index = 0
74
+ worker_rotation_lock = asyncio.Lock()
75
+
76
+
77
+ async def get_next_worker_connection(connections: list[AnyType]) -> tuple[int, AnyType]:
78
+ """Get the next worker connection using round-robin rotation."""
79
+ global worker_rotation_index
80
+ async with worker_rotation_lock:
81
+ if not connections:
82
+ raise RuntimeError("No worker connections available")
83
+ worker_idx = worker_rotation_index % len(connections)
84
+ worker_rotation_index += 1
85
+ return worker_idx, connections[worker_idx]
86
+
87
+
88
+ # -------- OpenAI /v1/chat/completions Pydantic Models ---------- #
89
+ class OAChatMessage(BaseModel):
90
+ role: str
91
+ content: str
92
+
93
+
94
+ class OAChatResponseFormat(BaseModel):
95
+ type: Literal["json_schema", "json_object"]
96
+ json_schema: Optional[dict] = None
97
+
98
+ model_config = ConfigDict(frozen=True)
99
+
100
+ def __hash__(self):
101
+ return hash(
102
+ (
103
+ self.type,
104
+ (
105
+ json.dumps(self.json_schema, sort_keys=True)
106
+ if self.json_schema
107
+ else None
108
+ ),
109
+ )
110
+ )
111
+
112
+
113
+ class OAChatCompletionRequest(BaseModel):
114
+ model: str
115
+ messages: list[OAChatMessage]
116
+ temperature: float | None = 0.7
117
+ top_p: float | None = 1.0
118
+ presence_penalty: float | None = 0.0
119
+ frequency_penalty: float | None = 0.0
120
+ max_tokens: int | None = 1024
121
+ n: int | None = 1
122
+ stop: str | list[str] | None = None
123
+ stream: bool = False # not supported
124
+ extra_body: dict | None = None
125
+ response_format: dict | None = None
126
+ # supported by vLLM:
127
+ # guided_decoding, include_stop_str_in_output, skip_special_tokens, spaces_between_special_tokens
128
+
129
+
130
+ class OAChatChoice(BaseModel):
131
+ index: int
132
+ message: OAChatMessage
133
+ finish_reason: str | None = "stop"
134
+
135
+
136
+ class OAChatCompletionResponse(BaseModel):
137
+ id: str
138
+ object: str = "chat.completion"
139
+ created: int
140
+ model: str
141
+ choices: list[OAChatChoice]
142
+
143
+
144
+ # -------- OpenAI /v1/completions Pydantic Models ---------- #
145
+ class OACompletionRequest(BaseModel):
146
+ model: str
147
+ prompt: str | list[str]
148
+ temperature: float | None = 0.7
149
+ top_p: float | None = 1.0
150
+ presence_penalty: float | None = 0.0
151
+ frequency_penalty: float | None = 0.0
152
+ max_tokens: int | None = 1024
153
+ n: int = 1
154
+ stop: str | list[str] | None = None
155
+ stream: bool = False # not supported
156
+ extra_body: dict | None = None
157
+
158
+
159
+ class OACompletionChoice(BaseModel):
160
+ index: int
161
+ text: str
162
+ logprobs: dict | None = None
163
+ finish_reason: str | None = "length"
164
+
165
+
166
+ class OACompletionResponse(BaseModel):
167
+ id: str
168
+ object: str = "completion"
169
+ created: int
170
+ model: str
171
+ choices: list[OACompletionChoice]
172
+
173
+
174
+ # ---------------------------------------------------------------------- #
175
+
176
+
177
+ def send_and_recv(conn: MPConnection, payload: dict):
178
+ """Helper to send a payload and receive a response over a pipe."""
179
+ # Use the global pipe_lock
180
+ with pipe_lock:
181
+ conn.send(payload)
182
+ return conn.recv()
183
+
184
+
185
+ async def async_send_and_recv(conn: MPConnection, payload: dict, timeout: float = 30.0):
186
+ """Async helper to send a payload and receive a response with timeout."""
187
+ loop = asyncio.get_running_loop()
188
+
189
+ # Send the payload in a thread to avoid blocking
190
+ async with asyncio.timeout(timeout):
191
+ try:
192
+ # Use the global pipe_lock in the executor
193
+ await loop.run_in_executor(None, lambda: pipe_lock.acquire())
194
+ try:
195
+ await loop.run_in_executor(None, conn.send, payload)
196
+
197
+ # Poll for response with timeout
198
+ start_time = loop.time()
199
+ while loop.time() - start_time < timeout:
200
+ if await loop.run_in_executor(None, conn.poll, 0.1):
201
+ result = await loop.run_in_executor(None, conn.recv)
202
+ return result
203
+ await asyncio.sleep(0.05) # Small sleep to avoid busy waiting
204
+
205
+ raise asyncio.TimeoutError(
206
+ f"Worker did not respond within {timeout} seconds"
207
+ )
208
+ finally:
209
+ await loop.run_in_executor(None, lambda: pipe_lock.release())
210
+ except asyncio.TimeoutError:
211
+ logger.error(f"Timeout waiting for worker response after {timeout}s")
212
+ raise
213
+ except Exception as e:
214
+ logger.error(f"Error in async_send_and_recv: {e}", exc_info=True)
215
+ raise
216
+
217
+
218
+ class WeightSyncWorkerExtension:
219
+ """
220
+ A vLLM worker extension that enables weight synchronization between a client and multiple server workers.
221
+
222
+ This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` to handle
223
+ efficient GPU-based communication using NCCL. The primary purpose of this class is to receive updated model weights
224
+ from a client process and distribute them to all worker processes participating in model inference.
225
+ """
226
+
227
+ # The following attributes are initialized when `init_communicator` method is called.
228
+ pynccl_comm = None # Communicator for weight updates
229
+ client_rank = None # Source rank for broadcasting updated weights
230
+
231
+ def init_communicator(self, host: str, port: int, world_size: int) -> None:
232
+ """
233
+ Initializes the weight update communicator using a stateless process group.
234
+
235
+ This method creates a `StatelessProcessGroup` that allows external training processes to
236
+ communicate with vLLM workers without interfering with the global torch distributed group.
237
+
238
+ Args:
239
+ host (`str`):
240
+ Hostname or IP address of the master node.
241
+ port (`int`):
242
+ Port number to be used for communication.
243
+ world_size (`int`):
244
+ Total number of participating processes in the update group.
245
+ """
246
+ if self.pynccl_comm is not None:
247
+ raise RuntimeError(
248
+ "Weight update group already initialized. Call close_communicator first."
249
+ )
250
+
251
+ # Get the rank of the current worker in the global world group.
252
+ rank = get_world_group().rank
253
+
254
+ # Log device information for debugging
255
+ logger.info(f"[WORKER] Initializing communicator: rank={rank}, device={self.device}, world_size={world_size}") # type: ignore
256
+
257
+ # Create a stateless process group to manage communication between training processes and vLLM workers.
258
+ pg = StatelessProcessGroup.create(
259
+ host=host, port=port, rank=rank, world_size=world_size
260
+ )
261
+
262
+ # Initialize the NCCL-based communicator for weight synchronization.
263
+ self.pynccl_comm = PyNcclCommunicator(pg, device=self.device) # type: ignore
264
+
265
+ # The client process that sends updated weights has the highest rank (world_size - 1).
266
+ self.client_rank = world_size - 1
267
+
268
+ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
269
+ """
270
+ Receives updated weights from the client process and updates the named parameter in the model.
271
+
272
+ Args:
273
+ name (`str`):
274
+ Name of the weight tensor being updated.
275
+ dtype (`str`):
276
+ String name of the weight tensor's data type (e.g., `"torch.float32"`); it is resolved to a `torch.dtype` before the receive buffer is allocated.
277
+ shape (`Sequence[int]`):
278
+ Shape of the weight tensor.
279
+ """
280
+ import logging
281
+
282
+ logger = logging.getLogger(__name__)
283
+ logger.debug(
284
+ f"[WORKER] Received weight update request for {name}, dtype={dtype}, shape={shape}"
285
+ )
286
+
287
+ if self.pynccl_comm is None:
288
+ raise RuntimeError(
289
+ "Communicator not initialized. Call `init_communicator` first."
290
+ )
291
+
292
+ dtype = getattr(torch, dtype.split(".")[-1])
293
+
294
+ # Allocate memory for the incoming weight tensor on the correct device.
295
+ weight = torch.empty(shape, dtype=dtype, device=self.device) # type: ignore
296
+
297
+ logger.debug(f"[WORKER] Starting NCCL broadcast for {name}")
298
+ # Use NCCL to broadcast the updated weights from the client (src) to all workers.
299
+ self.pynccl_comm.broadcast(weight, src=self.client_rank) # type: ignore
300
+ logger.debug(f"[WORKER] NCCL broadcast complete, waiting at barrier for {name}")
301
+ self.pynccl_comm.group.barrier()
302
+ logger.debug(f"[WORKER] Barrier passed, loading weights for {name}")
303
+
304
+ # Load the received weights into the model.
305
+ self.model_runner.model.load_weights(weights=[(name, weight)]) # type: ignore
306
+ logger.debug(f"[WORKER] Weight update complete for {name}")
307
+
308
+ def close_communicator(self) -> None:
309
+ """
310
+ Closes the communicator when weight synchronization is no longer needed.
311
+
312
+ This method deletes the NCCL communicator to release associated resources.
313
+ """
314
+
315
+ if self.pynccl_comm is not None:
316
+ del self.pynccl_comm
317
+ self.pynccl_comm = None # Ensure attribute is reset to None
318
+ self.client_rank = None # Ensure attribute is reset to None
319
+
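The trainer-side counterpart of `WeightSyncWorkerExtension` is not part of this file. The sketch below shows, under stated assumptions, how a client process could take part in the broadcast: it assumes `init_communicator` has already been invoked on every worker with the same host, port, and world size, and that the client occupies the last rank (`world_size - 1`), which is what the worker code above expects. The host, port, and world size values are hypothetical.

```python
import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

# Hypothetical values; they must match what was passed to init_communicator.
host, port, world_size = "localhost", 51216, 3
client_rank = world_size - 1  # the client broadcasts from the highest rank

pg = StatelessProcessGroup.create(
    host=host, port=port, rank=client_rank, world_size=world_size
)
comm = PyNcclCommunicator(pg, device=torch.device("cuda:0"))

# Broadcasting one updated parameter: each worker first calls
# update_named_param(name, dtype, shape) so it allocates a matching buffer,
# then receives this broadcast and loads it into the model.
weight = torch.zeros(16, 16, dtype=torch.float32, device="cuda:0")  # placeholder tensor
comm.broadcast(weight, src=client_rank)
comm.group.barrier()
```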
320
+
321
+ @dataclass
322
+ class ScriptArguments:
323
+ r"""
324
+ Arguments for the script.
325
+
326
+ Args:
327
+ model (`str`):
328
+ Model name or path to load the model from.
329
+ revision (`str` or `None`, *optional*, defaults to `None`):
330
+ Revision to use for the model. If not specified, the default branch will be used.
331
+ tensor_parallel_size (`int`, *optional*, defaults to `1`):
332
+ Number of tensor parallel workers to use.
333
+ data_parallel_size (`int`, *optional*, defaults to `1`):
334
+ Number of data parallel workers to use.
335
+ host (`str`, *optional*, defaults to `"0.0.0.0"`):
336
+ Host address to run the server on.
337
+ port (`int`, *optional*, defaults to `8000`):
338
+ Port to run the server on.
339
+ gpu_memory_utilization (`float`, *optional*, defaults to `0.95`):
340
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
341
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
342
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
343
+ during initialization.
344
+ dtype (`str`, *optional*, defaults to `"auto"`):
345
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
346
+ based on the model configuration. Find the supported values in the vLLM documentation.
347
+ max_model_len (`int` or `None`, *optional*, defaults to `8192`):
348
+ If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
349
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
350
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
351
+ enable_prefix_caching (`bool` or `None`, *optional*, defaults to `True`):
352
+ Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
353
+ this feature.
354
+ enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
355
+ Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
356
+ model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
357
+ kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
358
+ Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
359
+ log_level (`str`, *optional*, defaults to `"debug"`):
360
+ Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
361
+ `"trace"`.
362
+ max_batch_size (`int`, *optional*, defaults to `32`):
363
+ Maximum number of requests to process in one LLM call from the active pool.
364
+ batch_request_timeout_seconds (`int`, *optional*, defaults to `300`):
365
+ Timeout in seconds for a single request waiting for its turn and completion.
366
+ token_chunk_size (`int`, *optional*, defaults to `64`):
367
+ Number of tokens to generate per iteration per request in token-chunk dynamic batching.
368
+ """
369
+
370
+ model: str = field(metadata={"help": "Model name or path to load the model from."})
371
+ revision: Optional[str] = field(
372
+ default=None,
373
+ metadata={
374
+ "help": "Revision to use for the model. If not specified, the default branch will be used."
375
+ },
376
+ )
377
+ tensor_parallel_size: int = field(
378
+ default=1,
379
+ metadata={"help": "Number of tensor parallel workers to use."},
380
+ )
381
+ data_parallel_size: int = field(
382
+ default=1,
383
+ metadata={"help": "Number of data parallel workers to use."},
384
+ )
385
+ host: str = field(
386
+ default="0.0.0.0",
387
+ metadata={"help": "Host address to run the server on."},
388
+ )
389
+ port: int = field(
390
+ default=8000,
391
+ metadata={"help": "Port to run the server on."},
392
+ )
393
+ gpu_memory_utilization: float = field(
394
+ default=0.95,
395
+ metadata={
396
+ "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
397
+ "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
398
+ "size and thus improve the model's throughput. However, if the value is too high, it may cause "
399
+ "out-of-memory (OOM) errors during initialization."
400
+ },
401
+ )
402
+ dtype: str = field(
403
+ default="auto",
404
+ metadata={
405
+ "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
406
+ "determined based on the model configuration. Find the supported values in the vLLM documentation."
407
+ },
408
+ )
409
+ max_model_len: Optional[int] = field(
410
+ default=8192,
411
+ metadata={
412
+ "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
413
+ "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
414
+ "context size, which might be much larger than the KV cache, leading to inefficiencies."
415
+ },
416
+ )
417
+ enable_prefix_caching: Optional[bool] = field(
418
+ default=True,
419
+ metadata={
420
+ "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
421
+ "hardware support this feature."
422
+ },
423
+ )
424
+ enforce_eager: Optional[bool] = field(
425
+ default=None,
426
+ metadata={
427
+ "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
428
+ "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
429
+ "execution in hybrid."
430
+ },
431
+ )
432
+ kv_cache_dtype: str = field(
433
+ default="auto",
434
+ metadata={
435
+ "help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
436
+ },
437
+ )
438
+ log_level: str = field(
439
+ default="debug",
440
+ metadata={
441
+ "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', "
442
+ "'trace'."
443
+ },
444
+ )
445
+ max_batch_size: int = field(
446
+ default=32,
447
+ metadata={
448
+ "help": "Maximum number of requests to process in one LLM call from the active pool."
449
+ },
450
+ )
451
+ batch_request_timeout_seconds: int = field(
452
+ default=300,
453
+ metadata={
454
+ "help": "Timeout in seconds for a single request waiting for its turn and completion."
455
+ },
456
+ )
457
+ token_chunk_size: int = field(
458
+ default=64,
459
+ metadata={
460
+ "help": "Number of tokens to generate per iteration in token-chunk dynamic batching."
461
+ },
462
+ )
463
+
464
+
465
+ # Global/module-level variables for token-chunk dynamic batching
466
+ _SAMPLING_PARAM_NAMES: Optional[frozenset[str]] = None
467
+
468
+
469
+ @dataclass(frozen=True)
470
+ class PoolSignature:
471
+ model_name: str
472
+ request_type: Literal["chat", "completion"]
473
+ # Excludes max_tokens and stream
474
+ sampling_params_tuple: tuple[tuple[str, AnyType], ...]
475
+ extra_body_params_tuple: tuple[tuple[str, AnyType], ...]
476
+
477
+
478
+ @dataclass
479
+ class PooledRequestState:
480
+ original_request: AnyType # OAChatCompletionRequest or OACompletionRequest
481
+ completion_event: asyncio.Event
482
+ result_container: list
483
+ request_id: str
484
+ request_type: Literal["chat", "completion"]
485
+ pool_signature: PoolSignature # Store the signature for quick checks
486
+ effective_max_tokens: int
487
+ accumulated_content: str = ""
488
+ generated_token_count: int = 0
489
+ original_chat_messages: Optional[list[OAChatMessage]] = None
490
+ original_prompt: Optional[str | list[str]] = None
491
+ error: Optional[Exception] = None
492
+ finish_reason: Optional[str] = None
493
+ completed_and_signaled: bool = False
494
+ timed_out: bool = False
495
+
496
+ @property
497
+ def is_complete(self) -> bool:
498
+ """Single source of truth for whether this request is complete and should not be processed further."""
499
+ # Already finalized
500
+ if self.completed_and_signaled:
501
+ return True
502
+
503
+ # Error state
504
+ if self.error is not None:
505
+ return True
506
+
507
+ # Reached token limit
508
+ if self.generated_token_count >= self.effective_max_tokens:
509
+ return True
510
+
511
+ # Not enough room for meaningful generation (less than 1 token)
512
+ tokens_remaining = self.effective_max_tokens - self.generated_token_count
513
+ if tokens_remaining < 1:
514
+ return True
515
+
516
+ # vLLM indicated completion - but ignore "length" as that's just the chunk limit
517
+ if self.finish_reason is not None and self.finish_reason != "length":
518
+ return True
519
+
520
+ return False
521
+
522
+
523
+ pending_requests_by_signature: defaultdict[PoolSignature, list[PooledRequestState]] = (
524
+ defaultdict(list)
525
+ )
526
+ active_pool_signature: Optional[PoolSignature] = None
527
+ active_pool_requests: list[PooledRequestState] = []
528
+
529
+
530
+ def _get_sampling_param_names() -> frozenset[str]:
531
+ global _SAMPLING_PARAM_NAMES
532
+ if _SAMPLING_PARAM_NAMES is None:
533
+ _SAMPLING_PARAM_NAMES = frozenset(
534
+ inspect.signature(SamplingParams).parameters.keys()
535
+ | frozenset({"response_format"})
536
+ )
537
+ return _SAMPLING_PARAM_NAMES
538
+
539
+
540
+ def create_pool_signature(
541
+ model_name: str,
542
+ request_type: Literal["chat", "completion"],
543
+ raw_request_params: dict[
544
+ str, AnyType
545
+ ], # Contains all original request fields like temp, top_p, etc.
546
+ extra_body: Optional[dict[str, AnyType]],
547
+ ) -> PoolSignature:
548
+ valid_sampling_keys = _get_sampling_param_names()
549
+
550
+ sig_sampling_items = []
551
+ key_openai_to_vllm_map = {
552
+ "temperature": "temperature",
553
+ "top_p": "top_p",
554
+ "n": "n",
555
+ "presence_penalty": "presence_penalty",
556
+ "frequency_penalty": "frequency_penalty",
557
+ "stop": "stop",
558
+ "seed": "seed",
559
+ "ignore_eos": "ignore_eos",
560
+ "min_tokens": "min_tokens",
561
+ "response_format": "response_format",
562
+ }
563
+
564
+ # Use defaults from Pydantic models if not provided in request
565
+ param_defaults_for_sig = {
566
+ "temperature": OAChatCompletionRequest.model_fields["temperature"].default,
567
+ "top_p": OAChatCompletionRequest.model_fields["top_p"].default,
568
+ "presence_penalty": OAChatCompletionRequest.model_fields[
569
+ "presence_penalty"
570
+ ].default,
571
+ "frequency_penalty": OAChatCompletionRequest.model_fields[
572
+ "frequency_penalty"
573
+ ].default,
574
+ "n": OAChatCompletionRequest.model_fields["n"].default,
575
+ "stop": OAChatCompletionRequest.model_fields["stop"].default,
576
+ "response_format": OAChatCompletionRequest.model_fields[
577
+ "response_format"
578
+ ].default,
579
+ # stop: None, seed: None, ignore_eos: False, min_tokens: 0
580
+ }
581
+
582
+ for oa_key, vllm_key in key_openai_to_vllm_map.items():
583
+ if vllm_key in valid_sampling_keys:
584
+ value = raw_request_params.get(
585
+ oa_key, param_defaults_for_sig.get(oa_key)
586
+ ) # Use default if not in request
587
+ # For 'stop', ensure it's a tuple if it's a list for hashability
588
+ if vllm_key == "stop" and isinstance(value, list):
589
+ value = tuple(value)
590
+ if vllm_key == "response_format" and isinstance(value, dict):
591
+ value = json.dumps(value, sort_keys=True)
592
+ if (
593
+ value is not None
594
+ ): # Only add if value is meaningfully set or defaulted for signature
595
+ sig_sampling_items.append((vllm_key, value))
596
+
597
+ # Sort for stable signature
598
+ sig_sampling_items.sort(key=lambda item: item[0])
599
+
600
+ filtered_extra_body_items = []
601
+ if extra_body:
602
+ for k, v in sorted(extra_body.items()):
603
+ # Only include extra_body items that are NOT already part of vLLM SamplingParams
604
+ # to avoid them influencing signature if they are just alternative ways to pass standard params.
605
+ # This primarily targets things like 'guided_decoding_regex'.
606
+ if k not in valid_sampling_keys:
607
+ filtered_extra_body_items.append((k, v))
608
+
609
+ return PoolSignature(
610
+ model_name=model_name,
611
+ request_type=request_type,
612
+ sampling_params_tuple=tuple(sig_sampling_items),
613
+ extra_body_params_tuple=tuple(filtered_extra_body_items),
614
+ )
615
+
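To make the pooling rule concrete, the sketch below (using the `create_pool_signature` helper defined above) shows two hypothetical chat requests that differ only in `max_tokens` producing equal signatures, so they would be batched into the same pool; `max_tokens` and `stream` are deliberately excluded from the signature.

```python
# Illustrative only; "my-model" is a placeholder model name.
sig_a = create_pool_signature(
    model_name="my-model",
    request_type="chat",
    raw_request_params={"temperature": 0.7, "top_p": 1.0, "max_tokens": 128},
    extra_body=None,
)
sig_b = create_pool_signature(
    model_name="my-model",
    request_type="chat",
    raw_request_params={"temperature": 0.7, "top_p": 1.0, "max_tokens": 512},
    extra_body=None,
)
assert sig_a == sig_b  # same sampling settings -> same pool, despite max_tokens
```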
616
+
617
+ def llm_worker(
618
+ script_args: ScriptArguments,
619
+ data_parallel_rank: int,
620
+ master_port: int,
621
+ connection: MPConnection,
622
+ ) -> None:
623
+ # Set required environment variables for DP to work with vLLM
624
+ os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
625
+ os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
626
+ os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
627
+ os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
628
+
629
+ llm = LLM(
630
+ model=script_args.model,
631
+ revision=script_args.revision,
632
+ tensor_parallel_size=script_args.tensor_parallel_size,
633
+ gpu_memory_utilization=script_args.gpu_memory_utilization,
634
+ enforce_eager=script_args.enforce_eager,
635
+ dtype=script_args.dtype,
636
+ enable_prefix_caching=script_args.enable_prefix_caching,
637
+ max_model_len=script_args.max_model_len,
638
+ worker_extension_cls="arbor.server.services.inference.vllm_serve.WeightSyncWorkerExtension",
639
+ )
640
+
641
+ # Send ready signal to parent process
642
+ connection.send({"status": "ready"})
643
+
644
+ while True:
645
+ # Wait for commands from the parent process
646
+ try:
647
+ command = connection.recv()
648
+ except KeyboardInterrupt:
649
+ llm.collective_rpc(method="close_communicator")
650
+ break
651
+
652
+ # Handle commands
653
+ if command["type"] in ["call", "fire_and_forget"]:
654
+ method_name = command["method"]
655
+ args, kwargs = command.get("args", ()), command.get("kwargs", {})
656
+
657
+ # Add debugging
658
+ logger.debug(
659
+ f"[WORKER {data_parallel_rank}] Received command: {method_name}"
660
+ )
661
+
662
+ try:
663
+ method = getattr(llm, method_name)
664
+ logger.debug(
665
+ f"[WORKER {data_parallel_rank}] Calling {method_name} with kwargs keys: {list(kwargs.keys()) if kwargs else 'none'}"
666
+ )
667
+
668
+ # Call the method
669
+ result = method(*args, **kwargs)
670
+
671
+ logger.debug(
672
+ f"[WORKER {data_parallel_rank}] {method_name} completed, result type: {type(result)}"
673
+ )
674
+
675
+ if command["type"] == "call":
676
+ # Send result back
677
+ logger.debug(f"[WORKER {data_parallel_rank}] Sending result back")
678
+ connection.send(result)
679
+ logger.debug(f"[WORKER {data_parallel_rank}] Result sent")
680
+ except Exception as e:
681
+ logger.error(
682
+ f"[WORKER {data_parallel_rank}] Error in {method_name}: {e}",
683
+ exc_info=True,
684
+ )
685
+ if command["type"] == "call":
686
+ # Send error back as a special result
687
+ connection.send(
688
+ {"error": str(e), "traceback": traceback.format_exc()}
689
+ )
690
+ elif command["type"] == "shutdown":
691
+ logger.info(f"[WORKER {data_parallel_rank}] Received shutdown command")
692
+ break
693
+
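For clarity, the parent side of this command protocol just sends a dict over the pipe and waits for the reply. A minimal sketch, assuming `parent_conn` is the parent end of the `Pipe` whose child end was handed to `llm_worker`, and that the worker has already reported `{"status": "ready"}`:

```python
from vllm import SamplingParams

# A "call" command names an LLM method and its kwargs; the worker invokes it on
# its local vLLM engine and sends the result (or an error dict) back.
payload = {
    "type": "call",
    "method": "generate",
    "kwargs": {
        "prompts": ["Hello, world"],
        "sampling_params": SamplingParams(max_tokens=16),
    },
}
# send_and_recv is the helper defined earlier in this module.
outputs = send_and_recv(parent_conn, payload)  # blocks until the worker responds

# "fire_and_forget" runs a method without returning a result, and
# {"type": "shutdown"} asks the worker loop to exit.
```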
694
+
695
+ def chunk_list(lst: list, n: int) -> list[list]:
696
+ """
697
+ Split list `lst` into `n` evenly distributed sublists.
698
+
699
+ Example:
700
+ >>> chunk_list([1, 2, 3, 4, 5, 6], 2)
701
+ [[1, 2, 3], [4, 5, 6]]
702
+ >>> chunk_list([1, 2, 3, 4, 5, 6], 4)
703
+ [[1, 2], [3, 4], [5], [6]]
704
+ >>> chunk_list([1, 2, 3, 4, 5, 6], 8)
705
+ [[1], [2], [3], [4], [5], [6], [], []]
706
+ """
707
+ k, r = divmod(len(lst), n)
708
+ return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]
709
+
710
+
711
+ async def batch_processing_loop(
712
+ script_args: ScriptArguments,
713
+ connections: list[AnyType],
714
+ queue: asyncio.Queue, # This queue now receives PooledRequestState
715
+ logger_instance: logging.Logger,
716
+ ):
717
+ global pending_requests_by_signature, active_pool_signature, active_pool_requests
718
+ global active_generation_count, generation_count_lock
719
+
720
+ if not connections:
721
+ logger_instance.error(
722
+ "Batch Processor: No LLM workers available. Shutting down loop."
723
+ )
724
+ # We cannot process anything if connections are not there from the start.
725
+ # New requests added to queue will eventually timeout or error when picked by lifespan shutdown.
726
+ return
727
+
728
+ while True:
729
+ try:
730
+ # 1. Ingest new requests (non-blocking or short timeout)
731
+ # This part tries to fill pending_requests_by_signature from the main request_queue
732
+ try:
733
+ while True: # Loop to drain current items in asyncio.Queue
734
+ pooled_req_state: PooledRequestState = queue.get_nowait()
735
+ pending_requests_by_signature[
736
+ pooled_req_state.pool_signature
737
+ ].append(pooled_req_state)
738
+ logger_instance.debug(
739
+ f"[LIFECYCLE] Request {pooled_req_state.request_id} enqueued to pending pool. "
740
+ f"max_tokens={pooled_req_state.effective_max_tokens}, "
741
+ f"signature={pooled_req_state.pool_signature}"
742
+ )
743
+ queue.task_done() # Signal that this item from main queue is taken
744
+ except asyncio.QueueEmpty:
745
+ pass # No new requests in the main queue right now
746
+
747
+ # 2. Activate a new pool if current is empty and pending requests exist
748
+ if not active_pool_requests and pending_requests_by_signature:
749
+ # Simple strategy: pick the first signature that has pending requests
750
+ # More sophisticated strategies (e.g. largest batch) could be implemented here
751
+ # items() gives a view, convert to list to pop
752
+ available_signatures = list(pending_requests_by_signature.keys())
753
+ if available_signatures:
754
+ active_pool_signature = available_signatures[0] # Pick one
755
+ active_pool_requests = pending_requests_by_signature.pop(
756
+ active_pool_signature
757
+ )
758
+ logger_instance.info(
759
+ f"[LIFECYCLE] Activated new pool with signature: {active_pool_signature}. Size: {len(active_pool_requests)}"
760
+ )
761
+ for req in active_pool_requests:
762
+ logger_instance.debug(
763
+ f"[LIFECYCLE] Request {req.request_id} moved to active pool"
764
+ )
765
+
766
+ # 3. Merge new matching requests into the active pool
767
+ if (
768
+ active_pool_signature
769
+ and active_pool_signature in pending_requests_by_signature
770
+ ):
771
+ newly_matching_requests = pending_requests_by_signature.pop(
772
+ active_pool_signature
773
+ )
774
+ active_pool_requests.extend(newly_matching_requests)
775
+ logger_instance.info(
776
+ f"Merged {len(newly_matching_requests)} requests into active pool. New size: {len(active_pool_requests)}"
777
+ )
778
+
779
+ # 4. Process active pool chunk
780
+ if active_pool_requests:
781
+ # Take a sub-batch from the active pool
782
+ # If active_pool_requests is not empty, active_pool_signature must be set.
783
+ assert (
784
+ active_pool_signature is not None
785
+ ), "active_pool_signature cannot be None if active_pool_requests is populated"
786
+
787
+ # Filter out already-completed requests before selecting sub-batch
788
+ available_requests = [
789
+ req for req in active_pool_requests if not req.is_complete
790
+ ]
791
+
792
+ # Log the state of requests in the pool
793
+ logger_instance.debug(
794
+ f"Active pool has {len(active_pool_requests)} total requests, {len(available_requests)} available for processing"
795
+ )
796
+ for req in active_pool_requests:
797
+ logger_instance.debug(
798
+ f" Request {req.request_id}: tokens={req.generated_token_count}/{req.effective_max_tokens}, "
799
+ f"is_complete={req.is_complete}, finish_reason={req.finish_reason}"
800
+ )
801
+
802
+ if not available_requests:
803
+ # All requests in active pool are already completed, clear the pool
804
+ logger_instance.info(
805
+ f"All requests in active pool {active_pool_signature} are already completed. Clearing pool."
806
+ )
807
+ active_pool_requests.clear()
808
+ active_pool_signature = None
809
+ continue
810
+
811
+ sub_batch_to_process: list[PooledRequestState] = []
812
+ sub_batch_size = min(
813
+ len(available_requests), script_args.max_batch_size
814
+ )
815
+ sub_batch_to_process = available_requests[:sub_batch_size]
816
+
817
+ logger_instance.debug(
818
+ f"[BATCH_PROCESSOR] Processing sub-batch of {len(sub_batch_to_process)} for sig: {active_pool_signature}"
819
+ )
820
+
821
+ # Track active generation
822
+ async with generation_count_lock:
823
+ active_generation_count += 1
824
+ logger_instance.debug(
825
+ f"[BATCH_PROCESSOR] Active generation count increased to {active_generation_count}"
826
+ )
827
+
828
+ try:
829
+ # Prepare inputs for LLM
830
+ # All requests in sub_batch_to_process share active_pool_signature
831
+ # So, sampling params (except max_tokens) are the same.
832
+
833
+ # Construct SamplingParams from the active_pool_signature
834
+ # The signature stores param tuples. Convert back to dict for SamplingParams.
835
+ sig_sampling_dict = dict(
836
+ active_pool_signature.sampling_params_tuple
837
+ )
838
+ sig_extra_body_dict = dict(
839
+ active_pool_signature.extra_body_params_tuple
840
+ )
841
+
842
+ # Override 'n' to 1 for chunked generation as per design.
843
+ # Log if original 'n' was different.
844
+ original_n = sig_sampling_dict.get("n", 1)
845
+ if original_n != 1:
846
+ logger_instance.warning(
847
+ f"Pool {active_pool_signature}: Original 'n={original_n}' overridden to n=1 for chunked generation."
848
+ )
849
+
850
+ # Calculate the minimum tokens remaining across all requests in the batch
851
+ min_tokens_remaining = min(
852
+ req.effective_max_tokens - req.generated_token_count
853
+ for req in sub_batch_to_process
854
+ )
855
+
856
+ # Log token calculations
857
+ logger_instance.debug(
858
+ f"[CHUNK_CALC] Calculating chunk size for {len(sub_batch_to_process)} requests:"
859
+ )
860
+ for req in sub_batch_to_process:
861
+ tokens_left = (
862
+ req.effective_max_tokens - req.generated_token_count
863
+ )
864
+ logger_instance.debug(
865
+ f"[CHUNK_CALC] Request {req.request_id}: {req.generated_token_count}/{req.effective_max_tokens} tokens, {tokens_left} remaining"
866
+ )
867
+
868
+ # Limit chunk size to available room
869
+ chunk_size = min(script_args.token_chunk_size, min_tokens_remaining)
870
+ logger_instance.debug(
871
+ f"[CHUNK_CALC] Final chunk size: {chunk_size} (configured: {script_args.token_chunk_size}, min_remaining: {min_tokens_remaining})"
872
+ )
873
+
874
+ # CRITICAL: Ensure chunk_size is at least 1 to avoid vLLM issues
875
+ if chunk_size <= 0:
876
+ logger_instance.error(
877
+ f"Invalid chunk size {chunk_size} calculated. Min remaining: {min_tokens_remaining}"
878
+ )
879
+ # Mark all requests as complete if we can't generate any more tokens
880
+ for req_state in sub_batch_to_process:
881
+ if not req_state.is_complete:
882
+ req_state.finish_reason = "length"
883
+ logger_instance.info(
884
+ f"Request {req_state.request_id} marked complete due to no room for generation"
885
+ )
886
+ continue # Skip this iteration
887
+
888
+ logger_instance.debug(
889
+ f"Chunk size for batch: {chunk_size} (min remaining: {min_tokens_remaining}, configured: {script_args.token_chunk_size})"
890
+ )
891
+
892
+ # Create a new dict for **kwargs, excluding 'n' as it's set explicitly
893
+ kwargs_for_sampling_params = {
894
+ k: v for k, v in sig_sampling_dict.items() if k != "n"
895
+ }
896
+
897
+ guided_decoding = None
898
+ if "guided_decoding_regex" in sig_extra_body_dict:
899
+ guided_decoding = GuidedDecodingParams(
900
+ backend="outlines",
901
+ regex=sig_extra_body_dict["guided_decoding_regex"],
902
+ )
903
+ elif "response_format" in kwargs_for_sampling_params:
904
+ response_format_json = json.loads(
905
+ kwargs_for_sampling_params["response_format"]
906
+ )
907
+ response_format_type = response_format_json["type"]
908
+ logger_instance.info(
909
+ f"Response format param: {response_format_type}"
910
+ )
911
+ if response_format_type == "json_schema":
912
+ json_schema = response_format_json["json_schema"]
913
+ guided_decoding = GuidedDecodingParams(json=json_schema)
914
+ logger_instance.info(
915
+ f"Going with json_schema {json_schema}"
916
+ )
917
+ elif response_format_type == "json_object":
918
+ logger_instance.info(f"Going with json_object")
919
+ guided_decoding = GuidedDecodingParams(json_object=True)
920
+ else:
921
+ logger_instance.info(
922
+ f"Response format param provided that isn't supported: {response_format_type}"
923
+ )
924
+ # remove the response_format key because we can't pass it to sampling params
925
+ del kwargs_for_sampling_params["response_format"]
926
+ else:
927
+ # guided_decoding should be none if we don't want to use any guided decoding
928
+ pass
929
+
930
+ vllm_sampling_params = SamplingParams(
931
+ **kwargs_for_sampling_params,
932
+ n=1, # Generate one sequence continuation per request in the chunk
933
+ max_tokens=chunk_size, # Use calculated chunk size
934
+ # Ensure guided_decoding is correctly set up if present in extra_body
935
+ guided_decoding=guided_decoding,
936
+ # Remove any params from extra_body that might also be in SamplingParams if they were not filtered by create_pool_signature
937
+ **{
938
+ k: v
939
+ for k, v in sig_extra_body_dict.items()
940
+ if k in _get_sampling_param_names()
941
+ and k != "guided_decoding_regex"
942
+ },
943
+ )
944
+
945
+ # --- Bucket chat requests by first chunk vs continuing ---
946
+ first_chunk_inputs = []
947
+ first_chunk_states = []
948
+ continue_chunk_states = []
949
+ prompts_for_vllm = []
950
+ is_chat_pool = active_pool_signature.request_type == "chat"
951
+
952
+ for req_state in sub_batch_to_process:
953
+ if is_chat_pool:
954
+ current_messages = []
955
+ if req_state.original_chat_messages:
956
+ current_messages.extend(
957
+ [
958
+ m.model_dump()
959
+ for m in req_state.original_chat_messages
960
+ ]
961
+ )
962
+
963
+ # For continuing generation, we need to ensure there's an assistant message to continue
964
+ if req_state.generated_token_count == 0:
965
+ # First chunk - ensure we have a valid message sequence ending with user
966
+ if not current_messages:
967
+ logger_instance.error(
968
+ f"Request {req_state.request_id} has no messages"
969
+ )
970
+ req_state.error = ValueError("No messages provided")
971
+ continue
972
+ if current_messages[-1]["role"] != "user":
973
+ logger_instance.error(
974
+ f"Request {req_state.request_id} last message is not from user for first chunk"
975
+ )
976
+ req_state.error = ValueError(
977
+ "Last message must be from user for first chunk"
978
+ )
979
+ continue
980
+ first_chunk_inputs.append(current_messages)
981
+ first_chunk_states.append(req_state)
982
+ else:
983
+ # Continuing chunk - add accumulated content as assistant message
984
+ if req_state.accumulated_content:
985
+ current_messages.append(
986
+ {
987
+ "role": "assistant",
988
+ "content": req_state.accumulated_content,
989
+ }
990
+ )
991
+ else:
992
+ # This should not happen - we should have content if we're continuing
993
+ logger_instance.error(
994
+ f"Request {req_state.request_id} has no accumulated content for continuation"
995
+ )
996
+ req_state.error = ValueError(
997
+ "No content to continue"
998
+ )
999
+ continue
1000
+ continue_chunk_states.append(req_state)
1001
+ else:
1002
+ if isinstance(req_state.original_prompt, str):
1003
+ current_prompt = req_state.original_prompt
1004
+ elif isinstance(req_state.original_prompt, list):
1005
+ current_prompt = (
1006
+ req_state.original_prompt[0]
1007
+ if req_state.original_prompt
1008
+ else ""
1009
+ )
1010
+ else:
1011
+ current_prompt = str(req_state.original_prompt or "")
1012
+ prompts_for_vllm.append(
1013
+ current_prompt + req_state.accumulated_content
1014
+ )
1015
+
1016
+ # Only process first-chunk chat requests in this tick, then continuing if no first-chunk left
1017
+ llm_results = []
1018
+ processed_states = []
1019
+ if is_chat_pool:
1020
+ loop = asyncio.get_running_loop()
1021
+ if first_chunk_inputs:
1022
+ # Filter out any already-completed requests from first_chunk_states
1023
+ active_first_states = []
1024
+ active_first_inputs = []
1025
+ for i, req_state in enumerate(first_chunk_states):
1026
+ if req_state.is_complete:
1027
+ logger_instance.debug(
1028
+ f"Skipping already-completed request {req_state.request_id} in first chunk processing"
1029
+ )
1030
+ continue
1031
+ active_first_states.append(req_state)
1032
+ active_first_inputs.append(first_chunk_inputs[i])
1033
+
1034
+ if not active_first_states:
1035
+ logger_instance.debug(
1036
+ "All first chunk requests are already completed, skipping LLM call"
1037
+ )
1038
+ processed_states = []
1039
+ llm_results = []
1040
+ else:
1041
+ flags = dict(
1042
+ add_generation_prompt=True,
1043
+ continue_final_message=False,
1044
+ )
1045
+ payload = {
1046
+ "type": "call",
1047
+ "method": "chat",
1048
+ "kwargs": {
1049
+ "messages": active_first_inputs,
1050
+ "sampling_params": vllm_sampling_params,
1051
+ **flags,
1052
+ },
1053
+ }
1054
+ logger_instance.debug(
1055
+ f"Sending first-chunk chat request to LLM with {len(active_first_inputs)} messages"
1056
+ )
1057
+
1058
+ worker_idx = -1 # Initialize to avoid unbound variable
1059
+ try:
1060
+ worker_idx, worker_conn = (
1061
+ await get_next_worker_connection(connections)
1062
+ )
1063
+ logger_instance.debug(
1064
+ f"Using worker {worker_idx} for first-chunk chat request"
1065
+ )
1066
+ llm_results = await async_send_and_recv(
1067
+ worker_conn, payload, timeout=60.0
1068
+ )
1069
+ logger_instance.debug(
1070
+ f"Received {len(llm_results)} results from LLM for first-chunk chat"
1071
+ )
1072
+ except asyncio.TimeoutError:
1073
+ logger_instance.error(
1074
+ f"Worker {worker_idx} timeout for first-chunk chat after 60s"
1075
+ )
1076
+ for req_state in active_first_states:
1077
+ req_state.error = TimeoutError(
1078
+ "Worker timeout during generation"
1079
+ )
1080
+ llm_results = []
1081
+ except Exception as e:
1082
+ logger_instance.error(
1083
+ f"Error calling LLM for first-chunk chat: {e}",
1084
+ exc_info=True,
1085
+ )
1086
+ for req_state in active_first_states:
1087
+ req_state.error = e
1088
+ llm_results = []
1089
+ processed_states = active_first_states
1090
+ elif continue_chunk_states:
1091
+ # No first-chunk requests, process continuing requests
1092
+ continue_chunk_inputs = []
1093
+ # Filter out any already-completed requests
1094
+ active_continue_states = []
1095
+ for req_state in continue_chunk_states:
1096
+ if req_state.is_complete:
1097
+ logger_instance.debug(
1098
+ f"Skipping already-completed request {req_state.request_id} in continue chunk processing"
1099
+ )
1100
+ continue
1101
+ active_continue_states.append(req_state)
1102
+ current_messages = []
1103
+ if req_state.original_chat_messages:
1104
+ current_messages.extend(
1105
+ [
1106
+ m.model_dump()
1107
+ for m in req_state.original_chat_messages
1108
+ ]
1109
+ )
1110
+
1111
+ # Must have accumulated content to continue
1112
+ if not req_state.accumulated_content:
1113
+ logger_instance.error(
1114
+ f"Request {req_state.request_id} has no accumulated content for continuation"
1115
+ )
1116
+ req_state.error = ValueError(
1117
+ "No content to continue generation"
1118
+ )
1119
+ active_continue_states.remove(
1120
+ req_state
1121
+ ) # Remove from active list
1122
+ continue
1123
+
1124
+ # Add the accumulated content as the assistant message to continue
1125
+ current_messages.append(
1126
+ {
1127
+ "role": "assistant",
1128
+ "content": req_state.accumulated_content,
1129
+ }
1130
+ )
1131
+ continue_chunk_inputs.append(current_messages)
1132
+
1133
+ if not active_continue_states:
1134
+ logger_instance.debug(
1135
+ "All continue chunk requests are already completed, skipping LLM call"
1136
+ )
1137
+ processed_states = []
1138
+ llm_results = []
1139
+ else:
1140
+ flags = dict(
1141
+ add_generation_prompt=False,
1142
+ continue_final_message=True,
1143
+ )
1144
+ payload = {
1145
+ "type": "call",
1146
+ "method": "chat",
1147
+ "kwargs": {
1148
+ "messages": continue_chunk_inputs,
1149
+ "sampling_params": vllm_sampling_params,
1150
+ **flags,
1151
+ },
1152
+ }
1153
+ logger_instance.debug(
1154
+ f"Sending continue-chunk chat request to LLM with {len(continue_chunk_inputs)} messages"
1155
+ )
1156
+
1157
+ worker_idx = -1 # Initialize to avoid unbound variable
1158
+ try:
1159
+ worker_idx, worker_conn = (
1160
+ await get_next_worker_connection(connections)
1161
+ )
1162
+ logger_instance.debug(
1163
+ f"Using worker {worker_idx} for continue-chunk chat request"
1164
+ )
1165
+ llm_results = await async_send_and_recv(
1166
+ worker_conn, payload, timeout=60.0
1167
+ )
1168
+ logger_instance.debug(
1169
+ f"Received {len(llm_results)} results from LLM for continue-chunk chat"
1170
+ )
1171
+ except asyncio.TimeoutError:
1172
+ logger_instance.error(
1173
+ f"Worker {worker_idx} timeout for continue-chunk chat after 60s"
1174
+ )
1175
+ for req_state in active_continue_states:
1176
+ req_state.error = TimeoutError(
1177
+ "Worker timeout during generation"
1178
+ )
1179
+ llm_results = []
1180
+ except Exception as e:
1181
+ logger_instance.error(
1182
+ f"Error calling LLM for continue-chunk chat: {e}",
1183
+ exc_info=True,
1184
+ )
1185
+ for req_state in active_continue_states:
1186
+ req_state.error = e
1187
+ llm_results = []
1188
+ processed_states = active_continue_states
1189
+ else:
1190
+ # No requests to process in this iteration
1191
+ logger_instance.debug(
1192
+ "No chat requests to process in this iteration"
1193
+ )
1194
+ processed_states = []
1195
+ llm_results = []
1196
+ else:
1197
+ # completion – unchanged
1198
+ loop = asyncio.get_running_loop()
1199
+ payload = {
1200
+ "type": "call",
1201
+ "method": "generate",
1202
+ "kwargs": {
1203
+ "prompts": prompts_for_vllm,
1204
+ "sampling_params": vllm_sampling_params,
1205
+ },
1206
+ }
1207
+ logger_instance.debug(
1208
+ f"Sending completion request to LLM with {len(prompts_for_vllm)} prompts"
1209
+ )
1210
+ worker_idx = -1 # Initialize to avoid unbound variable
1211
+ try:
1212
+ worker_idx, worker_conn = await get_next_worker_connection(
1213
+ connections
1214
+ )
1215
+ logger_instance.debug(
1216
+ f"Using worker {worker_idx} for completion request"
1217
+ )
1218
+ llm_results = await async_send_and_recv(
1219
+ worker_conn, payload, timeout=60.0
1220
+ )
1221
+ logger_instance.debug(
1222
+ f"Received {len(llm_results)} results from LLM for completion"
1223
+ )
1224
+ except asyncio.TimeoutError:
1225
+ logger_instance.error(
1226
+ f"Worker {worker_idx} timeout for completion after 60s"
1227
+ )
1228
+ for req_state in sub_batch_to_process:
1229
+ req_state.error = TimeoutError(
1230
+ "Worker timeout during generation"
1231
+ )
1232
+ llm_results = []
1233
+ except Exception as e:
1234
+ logger_instance.error(
1235
+ f"Error calling LLM for completion: {e}", exc_info=True
1236
+ )
1237
+ for req_state in sub_batch_to_process:
1238
+ req_state.error = e
1239
+ llm_results = []
1240
+ processed_states = sub_batch_to_process
1241
+
1242
+ # Now, update state for each request in the processed_states
1243
+ temp_failed_requests_in_sub_batch: list[PooledRequestState] = []
1244
+
1245
+ if is_chat_pool:
1246
+ if processed_states and (
1247
+ len(llm_results) != len(processed_states)
1248
+ ):
1249
+ logger_instance.error(
1250
+ f"LLM result count mismatch. Expected {len(processed_states)}, got {len(llm_results)} for sig {active_pool_signature}. Marking affected requests in sub-batch as error."
1251
+ )
1252
+ for req_state in processed_states:
1253
+ if not req_state.completed_and_signaled:
1254
+ req_state.error = RuntimeError(
1255
+ "LLM result mismatch in batch processing."
1256
+ )
1257
+ req_state.finish_reason = "error"
1258
+ temp_failed_requests_in_sub_batch.append(req_state)
1259
+ else:
1260
+ real_idx = 0
1261
+ for req_state in processed_states:
1262
+ if req_state.completed_and_signaled:
1263
+ continue
1264
+ request_output = llm_results[real_idx]
1265
+ real_idx += 1
1266
+ if (
1267
+ not request_output.outputs
1268
+ or len(request_output.outputs) == 0
1269
+ ):
1270
+ logger_instance.warning(
1271
+ f"Request {req_state.request_id} (idx {real_idx-1}) received no output from vLLM in chunk."
1272
+ )
1273
+ # This might happen if vLLM can't generate any tokens (e.g., due to constraints)
1274
+ # Mark as complete rather than error
1275
+ req_state.finish_reason = (
1276
+ "stop" # vLLM couldn't generate
1277
+ )
1278
+ logger_instance.info(
1279
+ f"Request {req_state.request_id} marked complete due to empty vLLM output"
1280
+ )
1281
+ continue
1282
+ completion_output = request_output.outputs[0]
1283
+ new_text_chunk = completion_output.text
1284
+ req_state.accumulated_content += new_text_chunk
1285
+ new_token_count = len(completion_output.token_ids)
1286
+ req_state.generated_token_count += new_token_count
1287
+
1288
+ # Store vLLM's finish reason but we'll interpret it carefully
1289
+ vllm_finish_reason = completion_output.finish_reason
1290
+ logger_instance.debug(
1291
+ f"[VLLM_RESPONSE] Request {req_state.request_id}: vLLM returned {new_token_count} tokens, finish_reason={vllm_finish_reason}"
1292
+ )
1293
+
1294
+ # Only update our finish_reason if it's meaningful
1295
+ if vllm_finish_reason == "length":
1296
+ # vLLM hit the chunk limit - only set our finish_reason if we're at our actual limit
1297
+ if (
1298
+ req_state.generated_token_count
1299
+ >= req_state.effective_max_tokens
1300
+ ):
1301
+ req_state.finish_reason = "length"
1302
+ logger_instance.debug(
1303
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='length' - hit actual limit"
1304
+ )
1305
+ else:
1306
+ # Don't set finish_reason - we can continue generating
1307
+ logger_instance.debug(
1308
+ f"[FINISH_REASON] Request {req_state.request_id}: Ignoring vLLM's finish_reason='length' - only at chunk limit"
1309
+ )
1310
+ elif vllm_finish_reason is not None:
1311
+ # Other finish reasons (stop, eos_token, etc.) are real completions
1312
+ req_state.finish_reason = vllm_finish_reason
1313
+ logger_instance.debug(
1314
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='{vllm_finish_reason}' from vLLM"
1315
+ )
1316
+
1317
+ # Log detailed state for debugging
1318
+ logger_instance.debug(
1319
+ f"Request {req_state.request_id} chunk result: "
1320
+ f"new_tokens={new_token_count}, total_tokens={req_state.generated_token_count}, "
1321
+ f"finish_reason={req_state.finish_reason}, chunk_text_len={len(new_text_chunk)}"
1322
+ )
1323
+
1324
+ # Check if generation has stopped
1325
+ if new_token_count < chunk_size:
1326
+ # Incomplete chunk indicates generation should stop
1327
+ logger_instance.info(
1328
+ f"Request {req_state.request_id} generated incomplete chunk. "
1329
+ f"Generated {new_token_count}/{chunk_size} tokens in chunk. "
1330
+ f"Marking as complete to prevent doom loop."
1331
+ )
1332
+ # Set finish reason if not already set by vLLM
1333
+ if req_state.finish_reason is None:
1334
+ req_state.finish_reason = (
1335
+ "stop" # Generation naturally stopped
1336
+ )
1337
+ logger_instance.debug(
1338
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='stop' due to incomplete chunk"
1339
+ )
1340
+
1341
+ # Log current state
1342
+ logger_instance.debug(
1343
+ f"Request {req_state.request_id} chunk processed. Tokens in chunk: {new_token_count}, total: {req_state.generated_token_count}, is_complete: {req_state.is_complete}"
1344
+ )
1345
+ else:
1346
+ # completion – unchanged
1347
+ if len(llm_results) != len(processed_states):
1348
+ logger_instance.error(
1349
+ f"LLM result count mismatch. Expected {len(processed_states)}, got {len(llm_results)} for sig {active_pool_signature}. Marking affected requests in sub-batch as error."
1350
+ )
1351
+ for req_state in processed_states:
1352
+ if not req_state.completed_and_signaled:
1353
+ req_state.error = RuntimeError(
1354
+ "LLM result mismatch in batch processing."
1355
+ )
1356
+ req_state.finish_reason = "error"
1357
+ temp_failed_requests_in_sub_batch.append(req_state)
1358
+ else:
1359
+ real_idx = 0
1360
+ for req_state in processed_states:
1361
+ if req_state.completed_and_signaled:
1362
+ continue
1363
+ request_output = llm_results[real_idx]
1364
+ real_idx += 1
1365
+ if (
1366
+ not request_output.outputs
1367
+ or len(request_output.outputs) == 0
1368
+ ):
1369
+ logger_instance.warning(
1370
+ f"Request {req_state.request_id} (idx {real_idx-1}) received no output from vLLM in chunk."
1371
+ )
1372
+ # This might happen if vLLM can't generate any tokens (e.g., due to constraints)
1373
+ # Mark as complete rather than error
1374
+ req_state.finish_reason = (
1375
+ "stop" # vLLM couldn't generate
1376
+ )
1377
+ logger_instance.info(
1378
+ f"Request {req_state.request_id} marked complete due to empty vLLM output"
1379
+ )
1380
+ continue
1381
+ completion_output = request_output.outputs[0]
1382
+ new_text_chunk = completion_output.text
1383
+ req_state.accumulated_content += new_text_chunk
1384
+ new_token_count = len(completion_output.token_ids)
1385
+ req_state.generated_token_count += new_token_count
1386
+
1387
+ # Store vLLM's finish reason but we'll interpret it carefully
1388
+ vllm_finish_reason = completion_output.finish_reason
1389
+ logger_instance.debug(
1390
+ f"[VLLM_RESPONSE] Request {req_state.request_id}: vLLM returned {new_token_count} tokens, finish_reason={vllm_finish_reason}"
1391
+ )
1392
+
1393
+ # Only update our finish_reason if it's meaningful
1394
+ if vllm_finish_reason == "length":
1395
+ # vLLM hit the chunk limit - only set our finish_reason if we're at our actual limit
1396
+ if (
1397
+ req_state.generated_token_count
1398
+ >= req_state.effective_max_tokens
1399
+ ):
1400
+ req_state.finish_reason = "length"
1401
+ logger_instance.debug(
1402
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='length' - hit actual limit"
1403
+ )
1404
+ else:
1405
+ # Don't set finish_reason - we can continue generating
1406
+ logger_instance.debug(
1407
+ f"[FINISH_REASON] Request {req_state.request_id}: Ignoring vLLM's finish_reason='length' - only at chunk limit"
1408
+ )
1409
+ elif vllm_finish_reason is not None:
1410
+ # Other finish reasons (stop, eos_token, etc.) are real completions
1411
+ req_state.finish_reason = vllm_finish_reason
1412
+ logger_instance.debug(
1413
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='{vllm_finish_reason}' from vLLM"
1414
+ )
1415
+
1416
+ # Log detailed state for debugging
1417
+ logger_instance.debug(
1418
+ f"Request {req_state.request_id} chunk result: "
1419
+ f"new_tokens={new_token_count}, total_tokens={req_state.generated_token_count}, "
1420
+ f"finish_reason={req_state.finish_reason}, chunk_text_len={len(new_text_chunk)}"
1421
+ )
1422
+
1423
+ # Check if generation has stopped
1424
+ if new_token_count < chunk_size:
1425
+ # Incomplete chunk indicates generation should stop
1426
+ logger_instance.info(
1427
+ f"Request {req_state.request_id} generated incomplete chunk. "
1428
+ f"Generated {new_token_count}/{chunk_size} tokens in chunk. "
1429
+ f"Marking as complete to prevent doom loop."
1430
+ )
1431
+ # Set finish reason if not already set by vLLM
1432
+ if req_state.finish_reason is None:
1433
+ req_state.finish_reason = (
1434
+ "stop" # Generation naturally stopped
1435
+ )
1436
+ logger_instance.debug(
1437
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='stop' due to incomplete chunk"
1438
+ )
1439
+
1440
+ # Log current state
1441
+ logger_instance.debug(
1442
+ f"Request {req_state.request_id} chunk processed. Tokens in chunk: {new_token_count}, total: {req_state.generated_token_count}, is_complete: {req_state.is_complete}"
1443
+ )
1444
+
1445
+ # Now, handle all finished or errored requests from this sub_batch
1446
+ # These need to be removed from active_pool_requests and have their events set.
1447
+ completed_in_sub_batch: list[PooledRequestState] = []
1448
+ remaining_in_sub_batch: list[PooledRequestState] = []
1449
+
1450
+ # Iterate over sub_batch_to_process to decide their fate
1451
+ updated_active_pool = []
1452
+
1453
+ # Create a set of request_ids from sub_batch_to_process for quick lookup
1454
+ sub_batch_ids = {r.request_id for r in sub_batch_to_process}
1455
+
1456
+ for req_state in active_pool_requests: # Iterate main active pool
1457
+ if req_state.request_id not in sub_batch_ids:
1458
+ updated_active_pool.append(
1459
+ req_state
1460
+ ) # Keep if not in current sub-batch
1461
+ continue
1462
+
1463
+ # req_state is from the current sub_batch. Check its status.
1464
+ logger_instance.debug(
1465
+ f"[COMPLETION_CHECK] Checking request {req_state.request_id}: "
1466
+ f"generated={req_state.generated_token_count}/{req_state.effective_max_tokens}, "
1467
+ f"finish_reason={req_state.finish_reason}, is_complete={req_state.is_complete}"
1468
+ )
1469
+
1470
+ if (
1471
+ req_state.is_complete
1472
+ and not req_state.completed_and_signaled
1473
+ ):
1474
+ # Request is complete but not yet signaled
1475
+ completed_in_sub_batch.append(req_state)
1476
+ logger_instance.info(
1477
+ f"[LIFECYCLE] Request {req_state.request_id} is complete and will be finalized. "
1478
+ f"Generated {req_state.generated_token_count} tokens, finish_reason={req_state.finish_reason}"
1479
+ )
1480
+ elif not req_state.is_complete:
1481
+ # Request is not complete, keep for next chunk
1482
+ updated_active_pool.append(req_state)
1483
+ logger_instance.debug(
1484
+ f"[LIFECYCLE] Request {req_state.request_id} is not complete, keeping for next chunk"
1485
+ )
1486
+ # If already signaled (completed_and_signaled is True), don't re-add or re-signal
1487
+
1488
+ active_pool_requests = updated_active_pool
1489
+ if not active_pool_requests:
1490
+ # Store signature before setting to None for logging
1491
+ deactivated_signature = active_pool_signature
1492
+ active_pool_signature = None # Deactivate pool if empty
1493
+ logger_instance.info(
1494
+ f"Deactivated pool {deactivated_signature} as it is now empty."
1495
+ )
1496
+
1497
+ for req_state in completed_in_sub_batch:
1498
+ if req_state.completed_and_signaled:
1499
+ continue # Already handled
1500
+
1501
+ response_content: (
1502
+ OAChatCompletionResponse
1503
+ | OACompletionResponse
1504
+ | JSONResponse
1505
+ )
1506
+ if req_state.error:
1507
+ logger_instance.error(
1508
+ f"Request {req_state.request_id} failed with error: {req_state.error}"
1509
+ )
1510
+ response_content = JSONResponse(
1511
+ status_code=500,
1512
+ content={
1513
+ "error": f"Processing error: {str(req_state.error)}",
1514
+ "request_id": req_state.request_id,
1515
+ },
1516
+ )
1517
+ elif req_state.request_type == "chat":
1518
+ final_choices = [
1519
+ OAChatChoice(
1520
+ index=0,
1521
+ message=OAChatMessage(
1522
+ role="assistant",
1523
+ content=req_state.accumulated_content,
1524
+ ),
1525
+ finish_reason=req_state.finish_reason,
1526
+ )
1527
+ ]
1528
+ response_content = OAChatCompletionResponse(
1529
+ id=f"chatcmpl-{uuid4().hex}", # Use original request_id if available? For now, new UUID.
1530
+ created=int(datetime.now(tz=timezone.utc).timestamp()),
1531
+ model=req_state.original_request.model,
1532
+ choices=final_choices,
1533
+ )
1534
+ else: # Completion
1535
+ final_choices = [
1536
+ OACompletionChoice(
1537
+ index=0,
1538
+ text=req_state.accumulated_content,
1539
+ finish_reason=req_state.finish_reason,
1540
+ )
1541
+ ]
1542
+ response_content = OACompletionResponse(
1543
+ id=f"cmpl-{uuid4().hex}",
1544
+ created=int(datetime.now(tz=timezone.utc).timestamp()),
1545
+ model=req_state.original_request.model,
1546
+ choices=final_choices,
1547
+ )
1548
+
1549
+ req_state.result_container[0] = response_content
1550
+ req_state.completion_event.set()
1551
+ req_state.completed_and_signaled = True
1552
+
1553
+ finally:
1554
+ # Always decrement active generation count
1555
+ async with generation_count_lock:
1556
+ active_generation_count -= 1
1557
+ logger_instance.debug(
1558
+ f"[BATCH_PROCESSOR] Active generation count decreased to {active_generation_count}"
1559
+ )
1560
+
1561
+ else: # No active pool
1562
+ await asyncio.sleep(
1563
+ 0.01
1564
+ ) # Small sleep if no active pool and pending queue was empty
1565
+
1566
+ except asyncio.CancelledError:
1567
+ logger_instance.info("Batch processing loop cancelled.")
1568
+ all_requests_to_cancel = list(active_pool_requests)
1569
+ active_pool_requests.clear()
1570
+ active_pool_signature = None
1571
+ for sig_list in pending_requests_by_signature.values():
1572
+ all_requests_to_cancel.extend(sig_list)
1573
+ pending_requests_by_signature.clear()
1574
+
1575
+ for req_state in all_requests_to_cancel:
1576
+ if not req_state.completed_and_signaled:
1577
+ req_state.result_container[0] = JSONResponse(
1578
+ status_code=503,
1579
+ content={"error": "Server shutting down, request cancelled."},
1580
+ )
1581
+ req_state.completion_event.set()
1582
+ req_state.completed_and_signaled = True
1583
+ break
1584
+ except Exception as e:
1585
+ logger_instance.error(
1586
+ f"Critical error in batch processing loop: {e}", exc_info=True
1587
+ )
1588
+ all_requests_to_fail = list(active_pool_requests)
1589
+ active_pool_requests.clear()
1590
+ active_pool_signature = None
1591
+ for sig_list in pending_requests_by_signature.values():
1592
+ all_requests_to_fail.extend(sig_list)
1593
+ pending_requests_by_signature.clear()
1594
+
1595
+ for req_state in all_requests_to_fail:
1596
+ if not req_state.completed_and_signaled:
1597
+ req_state.result_container[0] = JSONResponse(
1598
+ status_code=500,
1599
+ content={"error": f"Critical batch processor error: {str(e)}"},
1600
+ )
1601
+ req_state.completion_event.set()
1602
+ req_state.completed_and_signaled = True
1603
+ await asyncio.sleep(1) # Pause before retrying loop
1604
+
1605
+
1606
+ def main(script_args: ScriptArguments):
1607
+ global request_queue, batch_processor_task # Allow lifespan to assign to these
1608
+
1609
+ # Spawn dp workers, and setup pipes for communication
1610
+ master_port = get_open_port()
1611
+ connections: list[AnyType] = (
1612
+ []
1613
+ ) # Use Any type to avoid PipeConnection vs Connection mismatch
1614
+ processes = []
1615
+ for data_parallel_rank in range(script_args.data_parallel_size):
1616
+ # Use duplex=True to allow bidirectional communication
1617
+ # This is needed for "call" type commands that expect responses
1618
+ parent_connection, child_connection = Pipe(duplex=True)
1619
+ process = Process(
1620
+ target=llm_worker,
1621
+ args=(script_args, data_parallel_rank, master_port, child_connection),
1622
+ )
1623
+ process.start()
1624
+ connections.append(parent_connection)
1625
+ processes.append(process)
1626
+
1627
+ @asynccontextmanager
1628
+ async def lifespan(app: FastAPI):
1629
+ nonlocal processes # Capture from outer scope
1630
+ global request_queue, batch_processor_task # Defined at module level
1631
+
1632
+ logger.info(
1633
+ f"Lifespan: Waiting for {script_args.data_parallel_size} LLM worker(s) to be ready..."
1634
+ )
1635
+ ready_connections = set()
1636
+
1637
+ # Timeout for waiting for workers to get ready (e.g., 5 minutes)
1638
+ timeout_seconds = 300
1639
+ start_wait_time = time.time()
1640
+
1641
+ while len(ready_connections) < script_args.data_parallel_size:
1642
+ if time.time() - start_wait_time > timeout_seconds:
1643
+ logger.error(
1644
+ f"Lifespan: Timeout waiting for all LLM workers. Expected {script_args.data_parallel_size}, got {len(ready_connections)} ready."
1645
+ )
1646
+ raise RuntimeError("LLM workers failed to initialize in time")
1647
+
1648
+ for i, connection in enumerate(connections):
1649
+ if connection not in ready_connections:
1650
+ # Use poll() with a short timeout to avoid blocking indefinitely if a worker is stuck
1651
+ if connection.poll(
1652
+ 0.1
1653
+ ): # Check if data is available, with a 0.1s timeout
1654
+ try:
1655
+ msg = connection.recv()
1656
+ logger.info(
1657
+ f"Lifespan: Received message from worker {i}: {msg}"
1658
+ )
1659
+ if isinstance(msg, dict) and msg.get("status") == "ready":
1660
+ logger.info(f"Lifespan: LLM worker {i} reported ready.")
1661
+ ready_connections.add(connection)
1662
+ else:
1663
+ logger.warning(
1664
+ f"Lifespan: Received unexpected message from worker {i}: {msg}"
1665
+ )
1666
+ except Exception as e:
1667
+ logger.error(
1668
+ f"Lifespan: Error receiving message from worker {i}: {e}"
1669
+ )
1670
+
1671
+ if len(ready_connections) < script_args.data_parallel_size:
1672
+ time.sleep(
1673
+ 0.5
1674
+ ) # Brief sleep to avoid busy-waiting if not all workers are ready yet
1675
+
1676
+ if len(ready_connections) == script_args.data_parallel_size:
1677
+ logger.info(
1678
+ f"Lifespan: All {script_args.data_parallel_size} LLM worker(s) are ready. Proceeding to yield."
1679
+ )
1680
+ # Initialize request queue and start batch processor task
1681
+ request_queue = asyncio.Queue()
1682
+ logger.info(
1683
+ "Lifespan: Initialized request queue for batched chat completions."
1684
+ )
1685
+ batch_processor_task = asyncio.create_task(
1686
+ batch_processing_loop(script_args, connections, request_queue, logger)
1687
+ )
1688
+ logger.info("Lifespan: Started batch processing task for chat completions.")
1689
+ else:
1690
+ logger.error(
1691
+ f"Lifespan: Not all LLM workers became ready. Expected {script_args.data_parallel_size}, got {len(ready_connections)}. Uvicorn might not function correctly. Batch processor NOT started."
1692
+ )
1693
+
1694
+ yield
1695
+ logger.info("Lifespan: Uvicorn server is shutting down. Cleaning up resources.")
1696
+
1697
+ if batch_processor_task is not None and not batch_processor_task.done():
1698
+ logger.info("Lifespan: Cancelling batch processor task...")
1699
+ batch_processor_task.cancel()
1700
+ try:
1701
+ await batch_processor_task
1702
+ logger.info("Lifespan: Batch processor task finished.")
1703
+ except asyncio.CancelledError:
1704
+ logger.info("Lifespan: Batch processor task was cancelled as expected.")
1705
+ except Exception as e:
1706
+ logger.error(
1707
+ f"Lifespan: Exception during batch processor task shutdown: {e}",
1708
+ exc_info=True,
1709
+ )
1710
+
1711
+ # Wait for processes to terminate
1712
+ for process in processes:
1713
+ process.join(timeout=10) # Wait for 10 seconds for the process to terminate
1714
+ if process.is_alive():
1715
+ logger.warning(
1716
+ f"Process {process} is still alive after 10 seconds, attempting to terminate..."
1717
+ )
1718
+ process.terminate()
1719
+ process.join() # ensure process termination after calling terminate()
1720
+
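
The lifespan hook above waits for each spawned worker to report readiness over its end of the `Pipe`, and the endpoints defined below push `"call"` and `"fire_and_forget"` messages over the same connections. The real worker loop is `llm_worker`, defined earlier in this file; the sketch below only illustrates the message shapes that `main()` assumes, with illustrative names.

```python
# Minimal sketch of the pipe protocol the lifespan hook and endpoints rely on
# (the actual llm_worker is defined earlier in this module and may differ in detail).
from multiprocessing.connection import Connection


def worker_protocol_sketch(child_connection: Connection, llm) -> None:
    # 1. Report readiness so the lifespan hook can move past its wait loop.
    child_connection.send({"status": "ready"})

    # 2. Serve commands sent by the HTTP endpoints.
    while True:
        command = child_connection.recv()
        method = getattr(llm, command["method"])
        result = method(*command.get("args", ()), **command.get("kwargs", {}))
        if command["type"] == "call":
            # "call" commands (e.g. reset_prefix_cache) expect a reply.
            child_connection.send(result)
        # "fire_and_forget" commands (e.g. collective_rpc for weight updates)
        # send nothing back.
```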
1721
+ app = FastAPI(lifespan=lifespan)
1722
+
1723
+ # Define the endpoints for the model server
1724
+ @app.get("/health/")
1725
+ async def health():
1726
+ """
1727
+ Health check endpoint to verify that the server is running.
1728
+ """
1729
+ return {"status": "ok"}
1730
+
1731
+ @app.get("/get_world_size/")
1732
+ async def get_world_size():
1733
+ """
1734
+ Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.
1735
+
1736
+ Returns:
1737
+ `dict`:
1738
+ A dictionary containing the world size.
1739
+
1740
+ Example response:
1741
+ ```json
1742
+ {"world_size": 8}
1743
+ ```
1744
+ """
1745
+ return {
1746
+ "world_size": script_args.tensor_parallel_size
1747
+ * script_args.data_parallel_size
1748
+ }
1749
+
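
For a quick smoke test of these two endpoints, something like the following works; the base URL is an assumption and should match the `--host`/`--port` the server was started with.

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: local server address

print(requests.get(f"{BASE_URL}/health/").json())          # {"status": "ok"}
print(requests.get(f"{BASE_URL}/get_world_size/").json())  # {"world_size": tp * dp}
```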
1750
+ # -------- OpenAI-chat ---------- #
1751
+ @app.post("/v1/chat/completions", response_model=OAChatCompletionResponse)
1752
+ async def openai_chat(req: OAChatCompletionRequest):
1753
+ global request_queue
1754
+
1755
+ if request_queue is None:
1756
+ logger.error("/v1/chat/completions: Request queue not initialized.")
1757
+ return JSONResponse(
1758
+ status_code=503,
1759
+ content={
1760
+ "error": "Server not ready, batch processing queue not initialized."
1761
+ },
1762
+ )
1763
+
1764
+ request_id = f"chatcmpl-{uuid4().hex}"
1765
+ logger.debug(f"Received chat request {request_id}, model: {req.model}")
1766
+
1767
+ # Create signature for pooling
1768
+ # The OAChatCompletionRequest fields are: model, messages, temperature, top_p, presence_penalty, frequency_penalty, response_format, max_tokens, stream, extra_body
1769
+ # We need to pass the relevant ones to create_pool_signature
1770
+ raw_params_for_sig = {
1771
+ "temperature": req.temperature,
1772
+ "top_p": req.top_p,
1773
+ "presence_penalty": req.presence_penalty,
1774
+ "frequency_penalty": req.frequency_penalty,
1775
+ # "n" is not in OAChatCompletionRequest, defaults to 1 for chat in OpenAI spec
1776
+ "n": 1,
1777
+ "response_format": req.response_format,
1778
+ }
1779
+ # Add other SamplingParams-mappable fields if they were part of OAChatCompletionRequest
1780
+ # For example, 'stop' or other sampling fields, if they are added to OAChatCompletionRequest.
1781
+ # For now, the above are the main ones.
1782
+
1783
+ default_max = OAChatCompletionRequest.model_fields["max_tokens"].default
1784
+ effective_max_tokens = req.max_tokens or default_max
1785
+
1786
+ pool_sig = create_pool_signature(
1787
+ model_name=req.model,
1788
+ request_type="chat",
1789
+ raw_request_params=raw_params_for_sig,
1790
+ extra_body=req.extra_body,
1791
+ )
1792
+
1793
+ completion_event = asyncio.Event()
1794
+ result_container = [None]
1795
+
1796
+ pooled_state = PooledRequestState(
1797
+ original_request=req,
1798
+ completion_event=completion_event,
1799
+ result_container=result_container,
1800
+ request_id=request_id,
1801
+ request_type="chat",
1802
+ pool_signature=pool_sig,
1803
+ effective_max_tokens=effective_max_tokens,
1804
+ original_chat_messages=req.messages, # Store original messages
1805
+ )
1806
+
1807
+ logger.info(
1808
+ f"[LIFECYCLE] Created chat request {request_id}: max_tokens={effective_max_tokens}, "
1809
+ f"messages={len(req.messages)}, model={req.model}"
1810
+ )
1811
+
1812
+ try:
1813
+ await request_queue.put(pooled_state)
1814
+ logger.debug(f"[LIFECYCLE] Request {request_id} successfully queued")
1815
+ except Exception as e:
1816
+ logger.error(f"Enqueueing error for {request_id}: {e}", exc_info=True)
1817
+ return JSONResponse(
1818
+ status_code=500,
1819
+ content={"error": "Internal server error while queueing request."},
1820
+ )
1821
+
1822
+ try:
1823
+ await asyncio.wait_for(
1824
+ completion_event.wait(),
1825
+ timeout=script_args.batch_request_timeout_seconds,
1826
+ )
1827
+ except asyncio.TimeoutError:
1828
+ logger.error(f"Timeout for chat request {request_id} (model {req.model}).")
1829
+ pooled_state.timed_out = True
1830
+ pooled_state.completed_and_signaled = True
1831
+ pooled_state.completion_event.set()
1832
+ return JSONResponse(
1833
+ status_code=504, content={"error": "Request timed out."}
1834
+ )
1835
+ except Exception as e:
1836
+ logger.error(
1837
+ f"Error waiting for completion event for {request_id}: {e}",
1838
+ exc_info=True,
1839
+ )
1840
+ return JSONResponse(
1841
+ status_code=500,
1842
+ content={
1843
+ "error": "Internal server error while waiting for completion."
1844
+ },
1845
+ )
1846
+
1847
+ response_data = result_container[0]
1848
+ if (
1849
+ response_data is None
1850
+ ): # Should ideally be set to an error by processor if timeout internally
1851
+ logger.error(
1852
+ f"No result for {request_id} (model {req.model}) after event set. Internal error."
1853
+ )
1854
+ return JSONResponse(
1855
+ status_code=500,
1856
+ content={"error": "Internal error: No result from processor."},
1857
+ )
1858
+
1859
+ if isinstance(response_data, JSONResponse):
1860
+ return response_data
1861
+
1862
+ if isinstance(response_data, OAChatCompletionResponse):
1863
+ # Must return JSONResponse for FastAPI
1864
+ return JSONResponse(response_data.model_dump())
1865
+ else:
1866
+ logger.error(
1867
+ f"Unexpected result type {type(response_data)} for {request_id}."
1868
+ )
1869
+ return JSONResponse(
1870
+ status_code=500,
1871
+ content={"error": "Internal error: Unexpected result format."},
1872
+ )
1873
+
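
From the client side this behaves like a regular OpenAI-style chat completions endpoint; internally, requests whose sampling parameters produce the same pool signature are batched together. A minimal client sketch, assuming a server reachable at `http://localhost:8000`:

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: local server address

# The served model name can be discovered via /v1/models (defined just below).
model_id = requests.get(f"{BASE_URL}/v1/models").json()["data"][0]["id"]

payload = {
    "model": model_id,
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "temperature": 0.7,
    "top_p": 0.95,
    "max_tokens": 64,
}
resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
choice = resp.json()["choices"][0]
print(choice["finish_reason"], choice["message"]["content"])
```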
1874
+ @app.get("/v1/models")
1875
+ async def list_models():
1876
+ return {
1877
+ "data": [{"id": script_args.model, "object": "model", "owned_by": "vllm"}]
1878
+ }
1879
+
1880
+ @app.post("/v1/completions", response_model=OACompletionResponse)
1881
+ async def openai_completions(req: OACompletionRequest):
1882
+ global request_queue
1883
+
1884
+ if request_queue is None:
1885
+ logger.error("/v1/completions: Request queue not initialized.")
1886
+ return JSONResponse(
1887
+ status_code=503,
1888
+ content={
1889
+ "error": "Server not ready, batch processing queue not initialized."
1890
+ },
1891
+ )
1892
+
1893
+ request_id = f"cmpl-{uuid4().hex}"
1894
+ logger.debug(f"Received completion request {request_id}, model: {req.model}")
1895
+
1896
+ # OACompletionRequest fields: model, prompt, temperature, top_p, presence_penalty, frequency_penalty, max_tokens, n, extra_body
1897
+ raw_params_for_sig = {
1898
+ "temperature": req.temperature,
1899
+ "top_p": req.top_p,
1900
+ "presence_penalty": req.presence_penalty,
1901
+ "frequency_penalty": req.frequency_penalty,
1902
+ "n": req.n, # Pass 'n' from the request
1903
+ }
1904
+ # Add other SamplingParams-mappable fields from OACompletionRequest if they exist
1905
+ # e.g., req.stop or other sampling fields, if they are added to the OACompletionRequest model.
1906
+ # For now, the above are the main ones.
1907
+ # We need to get ALL fields of OACompletionRequest that are also valid for SamplingParams
1908
+ # This is safer:
1909
+ valid_sp_keys = _get_sampling_param_names()
1910
+ for field_name, field_value in req.model_dump().items():
1911
+ if field_name in valid_sp_keys and field_name not in raw_params_for_sig:
1912
+ raw_params_for_sig[field_name] = field_value
1913
+
1914
+ default_max = OACompletionRequest.model_fields["max_tokens"].default
1915
+ effective_max_tokens = req.max_tokens or default_max
1916
+
1917
+ pool_sig = create_pool_signature(
1918
+ model_name=req.model,
1919
+ request_type="completion",
1920
+ raw_request_params=raw_params_for_sig,
1921
+ extra_body=req.extra_body,
1922
+ )
1923
+
1924
+ completion_event = asyncio.Event()
1925
+ result_container = [None]
1926
+
1927
+ # Check for list prompts for completion, which is problematic for current chunking.
1928
+ # vLLM's generate can take list of prompts, but our chunking logic (appending to prompt) assumes single string.
1929
+ if isinstance(req.prompt, list):
1930
+ if len(req.prompt) > 1:
1931
+ logger.warning(
1932
+ f"Request {request_id} for completion has a list of prompts. Only the first prompt will be used for chunked generation."
1933
+ )
1934
+ current_prompt = req.prompt[0] if req.prompt else ""
1935
+ elif not req.prompt: # empty list (simplified condition)
1936
+ current_prompt = ""
1937
+ else: # list with one element
1938
+ current_prompt = req.prompt[0]
1939
+ else: # string
1940
+ current_prompt = req.prompt
1941
+
1942
+ pooled_state = PooledRequestState(
1943
+ original_request=req,
1944
+ completion_event=completion_event,
1945
+ result_container=result_container,
1946
+ request_id=request_id,
1947
+ request_type="completion",
1948
+ pool_signature=pool_sig,
1949
+ effective_max_tokens=effective_max_tokens,
1950
+ original_prompt=current_prompt, # Store single prompt for chunking
1951
+ )
1952
+
1953
+ try:
1954
+ await request_queue.put(pooled_state)
1955
+ except Exception as e:
1956
+ logger.error(
1957
+ f"Enqueueing error for completion {request_id}: {e}", exc_info=True
1958
+ )
1959
+ return JSONResponse(
1960
+ status_code=500,
1961
+ content={"error": "Internal server error while queueing request."},
1962
+ )
1963
+
1964
+ try:
1965
+ await asyncio.wait_for(
1966
+ completion_event.wait(),
1967
+ timeout=script_args.batch_request_timeout_seconds,
1968
+ )
1969
+ except asyncio.TimeoutError:
1970
+ logger.error(
1971
+ f"Timeout for completion request {request_id} (model {req.model})."
1972
+ )
1973
+ pooled_state.timed_out = True
1974
+ pooled_state.completed_and_signaled = True
1975
+ pooled_state.completion_event.set()
1976
+ return JSONResponse(
1977
+ status_code=504, content={"error": "Request timed out."}
1978
+ )
1979
+ except Exception as e:
1980
+ logger.error(
1981
+ f"Error waiting for completion event for {request_id}: {e}",
1982
+ exc_info=True,
1983
+ )
1984
+ return JSONResponse(
1985
+ status_code=500,
1986
+ content={
1987
+ "error": "Internal server error while waiting for completion."
1988
+ },
1989
+ )
1990
+
1991
+ response_data = result_container[0]
1992
+ if response_data is None:
1993
+ logger.error(
1994
+ f"No result for completion {request_id} (model {req.model}) after event set. Internal error."
1995
+ )
1996
+ return JSONResponse(
1997
+ status_code=500,
1998
+ content={"error": "Internal error: No result from processor."},
1999
+ )
2000
+
2001
+ if isinstance(response_data, JSONResponse):
2002
+ return response_data
2003
+
2004
+ if isinstance(response_data, OACompletionResponse):
2005
+ return JSONResponse(
2006
+ response_data.model_dump()
2007
+ ) # Must return JSONResponse for FastAPI
2008
+ else:
2009
+ logger.error(
2010
+ f"Unexpected result type {type(response_data)} for completion {request_id}."
2011
+ )
2012
+ return JSONResponse(
2013
+ status_code=500,
2014
+ content={"error": "Internal error: Unexpected result format."},
2015
+ )
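
The completions endpoint mirrors the chat endpoint but takes a `prompt`. As the handling above shows, a list of prompts is accepted but only the first element is used for chunked generation, so single-string prompts are the safer choice. A minimal client sketch under the same assumptions as before:

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: local server address

payload = {
    "model": "served-model-name",          # placeholder; see /v1/models for the real id
    "prompt": "The capital of France is",  # prefer a single string over a list
    "max_tokens": 16,
    "temperature": 0.0,
    "n": 1,
}
resp = requests.post(f"{BASE_URL}/v1/completions", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```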
2016
+
2017
+ class InitCommunicatorRequest(BaseModel):
2018
+ host: str
2019
+ port: int
2020
+ world_size: int
2021
+
2022
+ @app.post("/init_communicator/")
2023
+ async def init_communicator(request: InitCommunicatorRequest):
2024
+ """
2025
+ Initializes the communicator for synchronizing model weights between a client and multiple server
2026
+ workers.
2027
+
2028
+ Args:
2029
+ request (`InitCommunicatorRequest`):
2030
+ - `host` (`str`): Hostname or IP address of the master node.
2031
+ - `port` (`int`): Port number to be used for communication.
2032
+ - `world_size` (`int`): Total number of participating processes in the group.
2033
+ """
2034
+ logger.info(
2035
+ f"[INIT_COMMUNICATOR] Received request: host={request.host}, port={request.port}, world_size={request.world_size}"
2036
+ )
2037
+
2038
+ # Calculate actual world size based on vLLM configuration
2039
+ vllm_world_size = (
2040
+ script_args.tensor_parallel_size * script_args.data_parallel_size
2041
+ )
2042
+ expected_world_size = vllm_world_size + 1 # +1 for the client
2043
+
2044
+ logger.info(
2045
+ f"[INIT_COMMUNICATOR] vLLM world size: {vllm_world_size} (TP={script_args.tensor_parallel_size} x DP={script_args.data_parallel_size})"
2046
+ )
2047
+ logger.info(
2048
+ f"[INIT_COMMUNICATOR] Expected total world size: {expected_world_size}"
2049
+ )
2050
+
2051
+ if request.world_size != expected_world_size:
2052
+ logger.warning(
2053
+ f"[INIT_COMMUNICATOR] World size mismatch! Request: {request.world_size}, Expected: {expected_world_size}"
2054
+ )
2055
+
2056
+ # The function init_communicator is called this way: init_communicator(host, port, world_size)
2057
+ # So with collective_rpc we need to call it this way:
2058
+ # llm.collective_rpc(method="init_communicator", args=(host, port, world_size))
2059
+ kwargs = {
2060
+ "method": "init_communicator",
2061
+ "args": (request.host, request.port, expected_world_size),
2062
+ }
2063
+
2064
+ # Send to all workers synchronously to ensure they're ready
2065
+ successful_workers = []
2066
+ failed_workers = []
2067
+
2068
+ for i, connection in enumerate(connections):
2069
+ logger.debug(f"[INIT_COMMUNICATOR] Sending to worker {i}")
2070
+ try:
2071
+ connection.send(
2072
+ {
2073
+ "type": "fire_and_forget",
2074
+ "method": "collective_rpc",
2075
+ "kwargs": kwargs,
2076
+ }
2077
+ )
2078
+ successful_workers.append(i)
2079
+ except Exception as e:
2080
+ logger.error(f"[INIT_COMMUNICATOR] Failed to notify worker {i}: {e}")
2081
+ failed_workers.append((i, str(e)))
2082
+
2083
+ if failed_workers:
2084
+ error_msg = f"Failed to notify workers: {failed_workers}"
2085
+ logger.error(f"[INIT_COMMUNICATOR] {error_msg}")
2086
+ return JSONResponse(status_code=500, content={"error": error_msg})
2087
+
2088
+ logger.info(
2089
+ f"[INIT_COMMUNICATOR] Successfully notified {len(successful_workers)} workers"
2090
+ )
2091
+ return {
2092
+ "message": "Request received, initializing communicator",
2093
+ "workers_notified": len(successful_workers),
2094
+ }
2095
+
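
The communicator group is expected to contain every vLLM worker plus the trainer/client, i.e. `tensor_parallel_size * data_parallel_size + 1` ranks. A hedged client-side sketch of the call (host, port, and parallelism values are placeholders, not values from this package):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: inference server address
TP, DP = 1, 1                       # assumption: match the server's launch flags

resp = requests.post(
    f"{BASE_URL}/init_communicator/",
    json={
        "host": "10.0.0.1",         # hypothetical master address for the weight-sync group
        "port": 51216,              # hypothetical free port
        "world_size": TP * DP + 1,  # vLLM workers + 1 for the trainer/client
    },
)
resp.raise_for_status()
print(resp.json())  # {"message": "Request received, initializing communicator", ...}
```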
2096
+ class UpdateWeightsRequest(BaseModel):
2097
+ name: str
2098
+ dtype: str
2099
+ shape: list[int]
2100
+
2101
+ class BatchUpdateWeightsRequest(BaseModel):
2102
+ updates: list[UpdateWeightsRequest]
2103
+
2104
+ @app.post("/update_named_param/")
2105
+ async def update_named_param(request: UpdateWeightsRequest):
2106
+ """
2107
+ Registers a weight update for the named parameter; the tensor data itself is transferred separately via NCCL broadcast.
2108
+
2109
+ Once this endpoint is called, the client process should broadcast the updated weights to all server workers.
2110
+
2111
+ Args:
2112
+ request (`UpdateWeightsRequest`):
2113
+ - `name` (`str`): Name of the weight tensor being updated.
2114
+ - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
2115
+ - `shape` (list of `int`): Shape of the weight tensor.
2116
+
2117
+ """
2118
+ # Acquire semaphore to limit concurrent weight updates
2119
+ async with weight_update_semaphore:
2120
+ logger.info(
2121
+ f"[UPDATE_PARAM] Received weight update for: {request.name}, dtype={request.dtype}, shape={request.shape}"
2122
+ )
2123
+
2124
+ # Wait for active generations to complete before updating weights
2125
+ # This prevents conflicts between generation and weight loading
2126
+ wait_start = time.time()
2127
+ if active_generation_count > 0:
2128
+ logger.info(
2129
+ f"[UPDATE_PARAM] Waiting for {active_generation_count} active generations to complete before weight update"
2130
+ )
2131
+ while active_generation_count > 0:
2132
+ if time.time() - wait_start > 30.0: # 30 second timeout
2133
+ logger.warning(
2134
+ f"[UPDATE_PARAM] Timeout waiting for {active_generation_count} active generations to complete"
2135
+ )
2136
+ break
2137
+ await asyncio.sleep(0.1)
2138
+ if active_generation_count == 0:
2139
+ logger.debug(
2140
+ f"[UPDATE_PARAM] All generations complete, proceeding with weight update"
2141
+ )
2142
+
2143
+ # CRITICAL: Notify workers IMMEDIATELY so they're ready for NCCL broadcast
2144
+ # This must happen before returning the HTTP response to maintain synchronization with trainer
2145
+ # Note: the dtype is forwarded as its string form below; the batch endpoint resolves it to a torch dtype first.
2146
+ kwargs = {
2147
+ "method": "update_named_param",
2148
+ "args": (request.name, request.dtype, tuple(request.shape)),
2149
+ }
2150
+
2151
+ # Send to all workers synchronously to ensure they're ready
2152
+ # Using fire_and_forget since we don't need the result
2153
+ for i, connection in enumerate(connections):
2154
+ logger.debug(f"[UPDATE_PARAM] Notifying worker {i} about weight update")
2155
+ try:
2156
+ connection.send(
2157
+ {
2158
+ "type": "fire_and_forget",
2159
+ "method": "collective_rpc",
2160
+ "kwargs": kwargs,
2161
+ }
2162
+ )
2163
+ except Exception as e:
2164
+ logger.error(f"[UPDATE_PARAM] Failed to notify worker {i}: {e}")
2165
+ return JSONResponse(
2166
+ status_code=500,
2167
+ content={"error": f"Failed to notify worker {i}: {str(e)}"},
2168
+ )
2169
+
2170
+ logger.debug(
2171
+ f"[UPDATE_PARAM] All workers notified, trainer should now broadcast via NCCL"
2172
+ )
2173
+ return {"message": "Weight update processed"}
2174
+
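
Only metadata travels over HTTP here; the tensor itself is expected to follow via an NCCL broadcast after the call returns, which is why all workers are notified before the response is sent. A client-side sketch under the same base-URL assumption; the broadcast step depends on the trainer's communicator and is only indicated by a comment:

```python
import requests
import torch

BASE_URL = "http://localhost:8000"  # assumption: inference server address


def push_named_param(name: str, tensor: torch.Tensor) -> None:
    # 1. Tell the server (and, through it, every vLLM worker) what is coming.
    resp = requests.post(
        f"{BASE_URL}/update_named_param/",
        json={
            "name": name,
            "dtype": str(tensor.dtype),  # e.g. "torch.bfloat16"
            "shape": list(tensor.shape),
        },
    )
    resp.raise_for_status()
    # 2. Broadcast the actual tensor over the group set up via /init_communicator/
    #    (trainer-specific; not shown here).
```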
2175
+ @app.post("/batch_update_named_params/")
2176
+ async def batch_update_named_params(request: BatchUpdateWeightsRequest):
2177
+ """
2178
+ Updates multiple model weights in a batch. Processes updates sequentially
2179
+ to ensure proper synchronization with NCCL broadcasts from the client.
2180
+
2181
+ Args:
2182
+ request (`BatchUpdateWeightsRequest`):
2183
+ - `updates` (list of `UpdateWeightsRequest`): List of weight updates to process
2184
+ """
2185
+ logger.info(
2186
+ f"[BATCH_UPDATE] Received batch of {len(request.updates)} weight updates"
2187
+ )
2188
+
2189
+ # Process updates sequentially to maintain NCCL synchronization
2190
+ # The client will broadcast each parameter after we notify workers
2191
+ successful = []
2192
+ errors = []
2193
+
2194
+ for update in request.updates:
2195
+ try:
2196
+ # Acquire semaphore to limit concurrent updates across different batches
2197
+ async with weight_update_semaphore:
2198
+ logger.debug(
2199
+ f"[BATCH_UPDATE] Processing weight update for: {update.name}"
2200
+ )
2201
+
2202
+ # Wait for active generations if needed
2203
+ wait_start = time.time()
2204
+ while active_generation_count > 0:
2205
+ if time.time() - wait_start > 30.0:
2206
+ logger.warning(
2207
+ f"[BATCH_UPDATE] Timeout waiting for generations"
2208
+ )
2209
+ break
2210
+ await asyncio.sleep(0.1)
2211
+
2212
+ # Notify workers synchronously
2213
+ dtype = getattr(torch, update.dtype.split(".")[-1])
2214
+ kwargs = {
2215
+ "method": "update_named_param",
2216
+ "args": (update.name, dtype, tuple(update.shape)),
2217
+ }
2218
+
2219
+ for i, connection in enumerate(connections):
2220
+ try:
2221
+ connection.send(
2222
+ {
2223
+ "type": "fire_and_forget",
2224
+ "method": "collective_rpc",
2225
+ "kwargs": kwargs,
2226
+ }
2227
+ )
2228
+ except Exception as e:
2229
+ logger.error(
2230
+ f"[BATCH_UPDATE] Failed to notify worker {i} for {update.name}: {e}"
2231
+ )
2232
+ raise Exception(f"Failed to notify worker {i}")
2233
+
2234
+ successful.append(update.name)
2235
+ logger.debug(f"[BATCH_UPDATE] Workers notified for {update.name}")
2236
+
2237
+ except Exception as e:
2238
+ errors.append({"param": update.name, "error": str(e)})
2239
+ logger.error(f"[BATCH_UPDATE] Error processing {update.name}: {e}")
2240
+
2241
+ if errors:
2242
+ return JSONResponse(
2243
+ status_code=207, # Multi-Status
2244
+ content={
2245
+ "message": f"Batch update completed with {len(errors)} errors",
2246
+ "successful": successful,
2247
+ "errors": errors,
2248
+ },
2249
+ )
2250
+
2251
+ logger.info(
2252
+ f"[BATCH_UPDATE] Successfully processed {len(successful)} weight updates"
2253
+ )
2254
+ return {
2255
+ "message": f"Successfully updated {len(successful)} parameters",
2256
+ "successful": successful,
2257
+ }
2258
+
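
The batch variant wraps the same per-parameter metadata in an `updates` list and processes it sequentially, returning HTTP 207 if only part of the batch succeeds, so the client should broadcast tensors in the same order it lists them. The request body looks roughly like this (parameter names and shapes are illustrative):

```python
# Shape of a /batch_update_named_params/ request body (values are illustrative).
batch_payload = {
    "updates": [
        {"name": "model.layers.0.mlp.up_proj.weight", "dtype": "torch.bfloat16", "shape": [11008, 4096]},
        {"name": "model.layers.0.mlp.down_proj.weight", "dtype": "torch.bfloat16", "shape": [4096, 11008]},
    ]
}
```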
2259
+ @app.post("/reset_prefix_cache/")
2260
+ async def reset_prefix_cache():
2261
+ """
2262
+ Resets the prefix cache for the model.
2263
+ """
2264
+ # Send requests and collect results synchronously
2265
+ all_outputs = []
2266
+ for connection in connections:
2267
+ try:
2268
+ connection.send({"type": "call", "method": "reset_prefix_cache"})
2269
+ output = connection.recv()
2270
+ all_outputs.append(output)
2271
+ except Exception as e:
2272
+ logger.error(f"Failed to reset prefix cache on worker: {e}")
2273
+ all_outputs.append(False)
2274
+
2275
+ success = all(output for output in all_outputs)
2276
+ return {
2277
+ "message": "Request received, resetting prefix cache status: "
2278
+ + str(success)
2279
+ }
2280
+
2281
+ @app.post("/close_communicator/")
2282
+ async def close_communicator():
2283
+ """
2284
+ Closes the weight update group and cleans up associated resources.
2285
+ """
2286
+ kwargs = {"method": "close_communicator"}
2287
+
2288
+ # Send to all workers
2289
+ for connection in connections:
2290
+ try:
2291
+ connection.send(
2292
+ {
2293
+ "type": "fire_and_forget",
2294
+ "method": "collective_rpc",
2295
+ "kwargs": kwargs,
2296
+ }
2297
+ )
2298
+ except Exception as e:
2299
+ logger.warning(f"Failed to send close_communicator to worker: {e}")
2300
+ # Don't fail the request if we can't notify a worker during shutdown
2301
+
2302
+ return {"message": "Request received, closing communicator"}
2303
+
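
A typical end-of-sync sequence on the client is to reset the prefix cache, so cached prefixes computed with the old weights are not reused, and then close the communicator once no further updates will be sent. Both are plain POSTs, sketched here under the same base-URL assumption:

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: inference server address

# Invalidate prefix-cache entries computed with the previous weights.
print(requests.post(f"{BASE_URL}/reset_prefix_cache/").json())

# Tear down the weight-sync group once no more updates will be sent.
print(requests.post(f"{BASE_URL}/close_communicator/").json())
```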
2304
+ # Start the server
2305
+ # Always use 1 Uvicorn worker. vLLM handles its own worker processes and scheduling.
2306
+ num_uvicorn_workers = 1
2307
+
2308
+ logger.info(
2309
+ f"Starting Uvicorn with {num_uvicorn_workers} worker(s). Data parallel size: {script_args.data_parallel_size}"
2310
+ )
2311
+ uvicorn.run(
2312
+ app,
2313
+ host=script_args.host,
2314
+ port=script_args.port,
2315
+ log_level=script_args.log_level,
2316
+ workers=num_uvicorn_workers,
2317
+ )
2318
+
2319
+
2320
+ def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
2321
+ if subparsers is not None:
2322
+ parser = subparsers.add_parser(
2323
+ "vllm-serve",
2324
+ help="Run the vLLM serve script",
2325
+ dataclass_types=ScriptArguments,
2326
+ )
2327
+ else:
2328
+ parser = TrlParser(ScriptArguments)
2329
+ return parser
2330
+
2331
+
2332
+ if __name__ == "__main__":
2333
+ parser = make_parser()
2334
+ (script_args,) = parser.parse_args_and_config()
2335
+ main(script_args)