arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2336 @@
1
+ # adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
2
+ """
3
+ OpenAI-compatible vLLM server with weight synchronization.
4
+
5
+ Usage:
6
+
7
+ ```bash
8
+ uv run python vllm_serve.py --model <model_name> --port <port>
9
+ ```
10
+
11
+ Supports:
12
+ - /v1/chat/completions
13
+ - /v1/completions
14
+ """
15
+
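For orientation, the server exposes standard OpenAI-style routes, so it can be exercised with any OpenAI-compatible client. A minimal client-side sketch (assuming the `openai` Python package and a server reachable on localhost:8000; neither is mandated by this file):

```python
# Minimal sketch: query the server's /v1/chat/completions route.
# Assumptions: `openai` installed, server on localhost:8000, and a model name
# matching whatever --model the server was launched with.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="my-model",  # placeholder for the served model name
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=32,
)
print(resp.choices[0].message.content)
```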
16
+ import argparse
17
+ import asyncio
18
+ import inspect
19
+ import json
20
+ import logging
21
+ import os
22
+ import threading
23
+ import time
24
+ import traceback
25
+ from collections import defaultdict
26
+ from collections.abc import Sequence
27
+ from contextlib import asynccontextmanager
28
+ from dataclasses import dataclass, field
29
+ from datetime import datetime, timezone
30
+ from multiprocessing import Pipe, Process
31
+ from multiprocessing.connection import Connection as MPConnection
32
+ from typing import Any
33
+ from typing import Any as AnyType
34
+ from typing import Literal, Optional
35
+ from uuid import uuid4
36
+
37
+ import torch
38
+ import uvicorn
39
+ from fastapi import FastAPI
40
+ from fastapi.responses import JSONResponse, StreamingResponse
41
+ from pydantic import BaseModel, ConfigDict
42
+ from trl import TrlParser
43
+ from vllm import LLM, SamplingParams
44
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
45
+ from vllm.distributed.parallel_state import get_world_group
46
+ from vllm.distributed.utils import StatelessProcessGroup
47
+ from vllm.sampling_params import GuidedDecodingParams
48
+ from vllm.utils import get_open_port
49
+
50
+ logging.basicConfig(
51
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
52
+ )
53
+ logger = logging.getLogger(__name__) # Ensure logger is defined
54
+
55
+ # We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following
56
+ # error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use
57
+ # the 'spawn' start method
58
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
59
+
60
+ # At the global level, after imports and logger setup:
61
+ pipe_lock = threading.Lock() # Global lock for pipe operations
62
+ request_queue: Optional[asyncio.Queue] = None
63
+ batch_processor_task: Optional[asyncio.Task] = None
64
+
65
+ # Generation tracking
66
+ active_generation_count = 0
67
+ generation_count_lock = asyncio.Lock()
68
+
69
+ # Weight update throttling
70
+ MAX_CONCURRENT_WEIGHT_UPDATES = 5
71
+ weight_update_semaphore = asyncio.Semaphore(MAX_CONCURRENT_WEIGHT_UPDATES)
72
+
73
+ # Worker rotation for load balancing
74
+ worker_rotation_index = 0
75
+ worker_rotation_lock = asyncio.Lock()
76
+
77
+
78
+ async def get_next_worker_connection(connections: list[AnyType]) -> tuple[int, AnyType]:
79
+ """Get the next worker connection using round-robin rotation."""
80
+ global worker_rotation_index
81
+ async with worker_rotation_lock:
82
+ if not connections:
83
+ raise RuntimeError("No worker connections available")
84
+ worker_idx = worker_rotation_index % len(connections)
85
+ worker_rotation_index += 1
86
+ return worker_idx, connections[worker_idx]
87
+
88
+
89
+ # -------- OpenAI /v1/chat/completions Pydantic Models ---------- #
90
+ class OAChatMessage(BaseModel):
91
+ role: str
92
+ content: str
93
+
94
+
95
+ class OAChatResponseFormat(BaseModel):
96
+ type: Literal["json_schema", "json_object"]
97
+ json_schema: Optional[dict] = None
98
+
99
+ model_config = ConfigDict(frozen=True)
100
+
101
+ def __hash__(self):
102
+ return hash(
103
+ (
104
+ self.type,
105
+ (
106
+ json.dumps(self.json_schema, sort_keys=True)
107
+ if self.json_schema
108
+ else None
109
+ ),
110
+ )
111
+ )
112
+
113
+
114
+ class OAChatCompletionRequest(BaseModel):
115
+ model: str
116
+ messages: list[OAChatMessage]
117
+ temperature: float | None = 0.7
118
+ top_p: float | None = 1.0
119
+ presence_penalty: float | None = 0.0
120
+ frequency_penalty: float | None = 0.0
121
+ max_tokens: int | None = 1024
122
+ n: int | None = 1
123
+ stop: str | list[str] | None = None
124
+ stream: bool = False # not supported
125
+ extra_body: dict | None = None
126
+ response_format: dict | None = None
127
+ # supported by vLLM:
128
+ # guided_decoding, include_stop_str_in_output, skip_special_tokens, spaces_between_special_tokens
129
+
130
+
131
+ class OAChatChoice(BaseModel):
132
+ index: int
133
+ message: OAChatMessage
134
+ finish_reason: str | None = "stop"
135
+
136
+
137
+ class OAChatCompletionResponse(BaseModel):
138
+ id: str
139
+ object: str = "chat.completion"
140
+ created: int
141
+ model: str
142
+ choices: list[OAChatChoice]
143
+
144
+
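As a quick illustration of the schema above (hand-written, not taken from the package), a JSON-mode chat request can be constructed and serialized against these models like so:

```python
# Illustrative only: build a request object against the models defined above.
req = OAChatCompletionRequest(
    model="my-model",
    messages=[OAChatMessage(role="user", content="Return a JSON greeting.")],
    temperature=0.0,
    max_tokens=64,
    response_format={"type": "json_object"},
)
print(req.model_dump_json(indent=2))
```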
145
+ # -------- OpenAI /v1/completions Pydantic Models ---------- #
146
+ class OACompletionRequest(BaseModel):
147
+ model: str
148
+ prompt: str | list[str]
149
+ temperature: float | None = 0.7
150
+ top_p: float | None = 1.0
151
+ presence_penalty: float | None = 0.0
152
+ frequency_penalty: float | None = 0.0
153
+ max_tokens: int | None = 1024
154
+ n: int = 1
155
+ stop: str | list[str] | None = None
156
+ stream: bool = False # not supported
157
+ extra_body: dict | None = None
158
+
159
+
160
+ class OACompletionChoice(BaseModel):
161
+ index: int
162
+ text: str
163
+ logprobs: dict | None = None
164
+ finish_reason: str | None = "length"
165
+
166
+
167
+ class OACompletionResponse(BaseModel):
168
+ id: str
169
+ object: str = "completion"
170
+ created: int
171
+ model: str
172
+ choices: list[OACompletionChoice]
173
+
174
+
175
+ # ---------------------------------------------------------------------- #
176
+
177
+
178
+ def send_and_recv(conn: MPConnection, payload: dict):
179
+ """Helper to send a payload and receive a response over a pipe."""
180
+ # Use the global pipe_lock
181
+ with pipe_lock:
182
+ conn.send(payload)
183
+ return conn.recv()
184
+
185
+
186
+ async def async_send_and_recv(conn: MPConnection, payload: dict, timeout: float = 30.0):
187
+ """Async helper to send a payload and receive a response with timeout."""
188
+ loop = asyncio.get_running_loop()
189
+
190
+ # Send the payload in a thread to avoid blocking
191
+ async with asyncio.timeout(timeout):
192
+ try:
193
+ # Use the global pipe_lock in the executor
194
+ await loop.run_in_executor(None, lambda: pipe_lock.acquire())
195
+ try:
196
+ await loop.run_in_executor(None, conn.send, payload)
197
+
198
+ # Poll for response with timeout
199
+                 start_time = loop.time()
200
+                 while loop.time() - start_time < timeout:
201
+ if await loop.run_in_executor(None, conn.poll, 0.1):
202
+ result = await loop.run_in_executor(None, conn.recv)
203
+ return result
204
+ await asyncio.sleep(0.05) # Small sleep to avoid busy waiting
205
+
206
+ raise asyncio.TimeoutError(
207
+ f"Worker did not respond within {timeout} seconds"
208
+ )
209
+ finally:
210
+ await loop.run_in_executor(None, lambda: pipe_lock.release())
211
+ except asyncio.TimeoutError:
212
+ logger.error(f"Timeout waiting for worker response after {timeout}s")
213
+ raise
214
+ except Exception as e:
215
+ logger.error(f"Error in async_send_and_recv: {e}", exc_info=True)
216
+ raise
217
+
218
+
219
+ class WeightSyncWorkerExtension:
220
+ """
221
+ A vLLM worker extension that enables weight synchronization between a client and multiple server workers.
222
+
223
+ This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` to handle
224
+ efficient GPU-based communication using NCCL. The primary purpose of this class is to receive updated model weights
225
+ from a client process and distribute them to all worker processes participating in model inference.
226
+ """
227
+
228
+ # The following attributes are initialized when `init_communicator` method is called.
229
+ pynccl_comm = None # Communicator for weight updates
230
+ client_rank = None # Source rank for broadcasting updated weights
231
+
232
+ def init_communicator(self, host: str, port: int, world_size: int) -> None:
233
+ """
234
+ Initializes the weight update communicator using a stateless process group.
235
+
236
+ This method creates a `StatelessProcessGroup` that allows external training processes to
237
+ communicate with vLLM workers without interfering with the global torch distributed group.
238
+
239
+ Args:
240
+ host (`str`):
241
+ Hostname or IP address of the master node.
242
+ port (`int`):
243
+ Port number to be used for communication.
244
+ world_size (`int`):
245
+ Total number of participating processes in the update group.
246
+ """
247
+ if self.pynccl_comm is not None:
248
+ raise RuntimeError(
249
+ "Weight update group already initialized. Call close_communicator first."
250
+ )
251
+
252
+ # Get the rank of the current worker in the global world group.
253
+ rank = get_world_group().rank
254
+
255
+ # Log device information for debugging
256
+ logger.info(f"[WORKER] Initializing communicator: rank={rank}, device={self.device}, world_size={world_size}") # type: ignore
257
+
258
+ # Create a stateless process group to manage communication between training processes and vLLM workers.
259
+ pg = StatelessProcessGroup.create(
260
+ host=host, port=port, rank=rank, world_size=world_size
261
+ )
262
+
263
+ # Initialize the NCCL-based communicator for weight synchronization.
264
+ self.pynccl_comm = PyNcclCommunicator(pg, device=self.device) # type: ignore
265
+
266
+ # The client process that sends updated weights has the highest rank (world_size - 1).
267
+ self.client_rank = world_size - 1
268
+
269
+ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
270
+ """
271
+ Receives updated weights from the client process and updates the named parameter in the model.
272
+
273
+ Args:
274
+ name (`str`):
275
+ Name of the weight tensor being updated.
276
+             dtype (`str`):
277
+                 String name of the weight tensor's data type (e.g., `"torch.float32"`).
278
+ shape (`Sequence[int]`):
279
+ Shape of the weight tensor.
280
+ """
281
+ import logging
282
+
283
+ logger = logging.getLogger(__name__)
284
+ logger.debug(
285
+ f"[WORKER] Received weight update request for {name}, dtype={dtype}, shape={shape}"
286
+ )
287
+
288
+ if self.pynccl_comm is None:
289
+ raise RuntimeError(
290
+ "Communicator not initialized. Call `init_communicator` first."
291
+ )
292
+
293
+ dtype = getattr(torch, dtype.split(".")[-1])
294
+
295
+ # Allocate memory for the incoming weight tensor on the correct device.
296
+ weight = torch.empty(shape, dtype=dtype, device=self.device) # type: ignore
297
+
298
+ logger.debug(f"[WORKER] Starting NCCL broadcast for {name}")
299
+ # Use NCCL to broadcast the updated weights from the client (src) to all workers.
300
+ self.pynccl_comm.broadcast(weight, src=self.client_rank) # type: ignore
301
+ logger.debug(f"[WORKER] NCCL broadcast complete, waiting at barrier for {name}")
302
+ self.pynccl_comm.group.barrier()
303
+ logger.debug(f"[WORKER] Barrier passed, loading weights for {name}")
304
+
305
+ # Load the received weights into the model.
306
+ self.model_runner.model.load_weights(weights=[(name, weight)]) # type: ignore
307
+ logger.debug(f"[WORKER] Weight update complete for {name}")
308
+
309
+ def close_communicator(self) -> None:
310
+ """
311
+ Closes the communicator when weight synchronization is no longer needed.
312
+
313
+ This method deletes the NCCL communicator to release associated resources.
314
+ """
315
+
316
+ if self.pynccl_comm is not None:
317
+ del self.pynccl_comm
318
+ self.pynccl_comm = None # Ensure attribute is reset to None
319
+ self.client_rank = None # Ensure attribute is reset to None
320
+
321
+
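For context, the training-process side of this handshake (not part of this file) has to mirror these calls: join the same `StatelessProcessGroup` as the last rank, then broadcast each tensor in the same order the workers execute `update_named_param`. A rough sketch under those assumptions:

```python
# Rough trainer-side sketch (assumptions: the trainer owns rank world_size - 1
# and separately asks the server to run update_named_param for each tensor, so
# every worker has a matching buffer allocated before the broadcast below).
import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup


def broadcast_model_weights(model: torch.nn.Module, host: str, port: int, world_size: int):
    device = torch.device("cuda:0")
    pg = StatelessProcessGroup.create(
        host=host, port=port, rank=world_size - 1, world_size=world_size
    )
    comm = PyNcclCommunicator(pg, device=device)
    for name, param in model.named_parameters():
        comm.broadcast(param.data.to(device), src=world_size - 1)
        comm.group.barrier()
```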
322
+ @dataclass
323
+ class ScriptArguments:
324
+ r"""
325
+ Arguments for the script.
326
+
327
+ Args:
328
+ model (`str`):
329
+ Model name or path to load the model from.
330
+ revision (`str` or `None`, *optional*, defaults to `None`):
331
+ Revision to use for the model. If not specified, the default branch will be used.
332
+ tensor_parallel_size (`int`, *optional*, defaults to `1`):
333
+ Number of tensor parallel workers to use.
334
+ data_parallel_size (`int`, *optional*, defaults to `1`):
335
+ Number of data parallel workers to use.
336
+ host (`str`, *optional*, defaults to `"0.0.0.0"`):
337
+ Host address to run the server on.
338
+ port (`int`, *optional*, defaults to `8000`):
339
+ Port to run the server on.
340
+ gpu_memory_utilization (`float`, *optional*, defaults to `0.95`):
341
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
342
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
343
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
344
+ during initialization.
345
+ dtype (`str`, *optional*, defaults to `"auto"`):
346
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
347
+ based on the model configuration. Find the supported values in the vLLM documentation.
348
+         max_model_len (`int` or `None`, *optional*, defaults to `8192`):
349
+ If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
350
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
351
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
352
+         enable_prefix_caching (`bool` or `None`, *optional*, defaults to `True`):
353
+ Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
354
+ this feature.
355
+ enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
356
+ Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
357
+ model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
358
+ kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
359
+ Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
360
+         log_level (`str`, *optional*, defaults to `"debug"`):
361
+ Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
362
+ `"trace"`.
363
+         max_batch_size (`int`, *optional*, defaults to `32`):
364
+             Maximum number of requests to process in one LLM call from the active pool.
365
+         batch_request_timeout_seconds (`int`, *optional*, defaults to `300`):
366
+             Timeout in seconds for a single request waiting for its turn and completion.
367
+         token_chunk_size (`int`, *optional*, defaults to `64`):
368
+             Number of tokens to generate per iteration per request in token-chunk dynamic batching.
369
+ """
370
+
371
+ model: str = field(metadata={"help": "Model name or path to load the model from."})
372
+ revision: Optional[str] = field(
373
+ default=None,
374
+ metadata={
375
+ "help": "Revision to use for the model. If not specified, the default branch will be used."
376
+ },
377
+ )
378
+ tensor_parallel_size: int = field(
379
+ default=1,
380
+ metadata={"help": "Number of tensor parallel workers to use."},
381
+ )
382
+ data_parallel_size: int = field(
383
+ default=1,
384
+ metadata={"help": "Number of data parallel workers to use."},
385
+ )
386
+ host: str = field(
387
+ default="0.0.0.0",
388
+ metadata={"help": "Host address to run the server on."},
389
+ )
390
+ port: int = field(
391
+ default=8000,
392
+ metadata={"help": "Port to run the server on."},
393
+ )
394
+ gpu_memory_utilization: float = field(
395
+ default=0.95,
396
+ metadata={
397
+ "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
398
+ "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
399
+ "size and thus improve the model's throughput. However, if the value is too high, it may cause "
400
+ "out-of-memory (OOM) errors during initialization."
401
+ },
402
+ )
403
+ dtype: str = field(
404
+ default="auto",
405
+ metadata={
406
+ "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
407
+ "determined based on the model configuration. Find the supported values in the vLLM documentation."
408
+ },
409
+ )
410
+ max_model_len: Optional[int] = field(
411
+ default=8192,
412
+ metadata={
413
+ "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
414
+ "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
415
+ "context size, which might be much larger than the KV cache, leading to inefficiencies."
416
+ },
417
+ )
418
+ enable_prefix_caching: Optional[bool] = field(
419
+ default=True,
420
+ metadata={
421
+ "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
422
+ "hardware support this feature."
423
+ },
424
+ )
425
+ enforce_eager: Optional[bool] = field(
426
+ default=None,
427
+ metadata={
428
+ "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
429
+ "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
430
+ "execution in hybrid."
431
+ },
432
+ )
433
+ kv_cache_dtype: str = field(
434
+ default="auto",
435
+ metadata={
436
+ "help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
437
+ },
438
+ )
439
+ log_level: str = field(
440
+ default="debug",
441
+ metadata={
442
+ "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', "
443
+ "'trace'."
444
+ },
445
+ )
446
+ max_batch_size: int = field(
447
+ default=32,
448
+ metadata={
449
+ "help": "Maximum number of requests to process in one LLM call from the active pool."
450
+ },
451
+ )
452
+ batch_request_timeout_seconds: int = field(
453
+ default=300,
454
+ metadata={
455
+ "help": "Timeout in seconds for a single request waiting for its turn and completion."
456
+ },
457
+ )
458
+ token_chunk_size: int = field(
459
+ default=64,
460
+ metadata={
461
+ "help": "Number of tokens to generate per iteration in token-chunk dynamic batching."
462
+ },
463
+ )
464
+
465
+
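These fields become CLI flags via `TrlParser` (imported at the top of this file). The script's actual entry point appears further down; it presumably follows the usual TRL pattern sketched here:

```python
# Hedged sketch of the usual TrlParser pattern; the real entry point later in
# this file may differ in its details.
if __name__ == "__main__":
    parser = TrlParser(ScriptArguments)
    (script_args,) = parser.parse_args_and_config()
    # e.g. --model <model_name> --data_parallel_size 2 --max_batch_size 16
```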
466
+ # Global/module-level variables for token-chunk dynamic batching
467
+ _SAMPLING_PARAM_NAMES: Optional[frozenset[str]] = None
468
+
469
+
470
+ @dataclass(frozen=True)
471
+ class PoolSignature:
472
+ model_name: str
473
+ request_type: Literal["chat", "completion"]
474
+ # Excludes max_tokens and stream
475
+ sampling_params_tuple: tuple[tuple[str, AnyType], ...]
476
+ extra_body_params_tuple: tuple[tuple[str, AnyType], ...]
477
+
478
+
479
+ @dataclass
480
+ class PooledRequestState:
481
+ original_request: AnyType # OAChatCompletionRequest or OACompletionRequest
482
+ completion_event: asyncio.Event
483
+ result_container: list
484
+ request_id: str
485
+ request_type: Literal["chat", "completion"]
486
+ pool_signature: PoolSignature # Store the signature for quick checks
487
+ effective_max_tokens: int
488
+ accumulated_content: str = ""
489
+ generated_token_count: int = 0
490
+ original_chat_messages: Optional[list[OAChatMessage]] = None
491
+ original_prompt: Optional[str | list[str]] = None
492
+ error: Optional[Exception] = None
493
+ finish_reason: Optional[str] = None
494
+ completed_and_signaled: bool = False
495
+ timed_out: bool = False
496
+
497
+ @property
498
+ def is_complete(self) -> bool:
499
+ """Single source of truth for whether this request is complete and should not be processed further."""
500
+ # Already finalized
501
+ if self.completed_and_signaled:
502
+ return True
503
+
504
+ # Error state
505
+ if self.error is not None:
506
+ return True
507
+
508
+ # Reached token limit
509
+ if self.generated_token_count >= self.effective_max_tokens:
510
+ return True
511
+
512
+ # Not enough room for meaningful generation (less than 1 token)
513
+ tokens_remaining = self.effective_max_tokens - self.generated_token_count
514
+ if tokens_remaining < 1:
515
+ return True
516
+
517
+ # vLLM indicated completion - but ignore "length" as that's just the chunk limit
518
+ if self.finish_reason is not None and self.finish_reason != "length":
519
+ return True
520
+
521
+ return False
522
+
523
+
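A small worked illustration of these rules (constructed for this note, not taken from the package): vLLM's `finish_reason='length'` only finishes a request once the request's own token budget is exhausted, whereas any other finish reason ends it immediately.

```python
# Illustrative only: "length" from the per-chunk cap does not complete a request.
_sig = PoolSignature("my-model", "chat", (), ())
state = PooledRequestState(
    original_request=None,
    completion_event=asyncio.Event(),
    result_container=[None],
    request_id="req-1",
    request_type="chat",
    pool_signature=_sig,
    effective_max_tokens=100,
)
state.generated_token_count = 64
state.finish_reason = "length"   # only the 64-token chunk cap was hit
assert not state.is_complete
state.finish_reason = "stop"     # a genuine stop from vLLM
assert state.is_complete
```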
524
+ pending_requests_by_signature: defaultdict[PoolSignature, list[PooledRequestState]] = (
525
+ defaultdict(list)
526
+ )
527
+ active_pool_signature: Optional[PoolSignature] = None
528
+ active_pool_requests: list[PooledRequestState] = []
529
+
530
+
531
+ def _get_sampling_param_names() -> frozenset[str]:
532
+ global _SAMPLING_PARAM_NAMES
533
+ if _SAMPLING_PARAM_NAMES is None:
534
+ _SAMPLING_PARAM_NAMES = frozenset(
535
+ inspect.signature(SamplingParams).parameters.keys()
536
+ | frozenset({"response_format"})
537
+ )
538
+ return _SAMPLING_PARAM_NAMES
539
+
540
+
541
+ def create_pool_signature(
542
+ model_name: str,
543
+ request_type: Literal["chat", "completion"],
544
+ raw_request_params: dict[
545
+ str, AnyType
546
+ ], # Contains all original request fields like temp, top_p, etc.
547
+ extra_body: Optional[dict[str, AnyType]],
548
+ ) -> PoolSignature:
549
+ valid_sampling_keys = _get_sampling_param_names()
550
+
551
+ sig_sampling_items = []
552
+ key_openai_to_vllm_map = {
553
+ "temperature": "temperature",
554
+ "top_p": "top_p",
555
+ "n": "n",
556
+ "presence_penalty": "presence_penalty",
557
+ "frequency_penalty": "frequency_penalty",
558
+ "stop": "stop",
559
+ "seed": "seed",
560
+ "ignore_eos": "ignore_eos",
561
+ "min_tokens": "min_tokens",
562
+ "response_format": "response_format",
563
+ }
564
+
565
+ # Use defaults from Pydantic models if not provided in request
566
+ param_defaults_for_sig = {
567
+ "temperature": OAChatCompletionRequest.model_fields["temperature"].default,
568
+ "top_p": OAChatCompletionRequest.model_fields["top_p"].default,
569
+ "presence_penalty": OAChatCompletionRequest.model_fields[
570
+ "presence_penalty"
571
+ ].default,
572
+ "frequency_penalty": OAChatCompletionRequest.model_fields[
573
+ "frequency_penalty"
574
+ ].default,
575
+ "n": OAChatCompletionRequest.model_fields["n"].default,
576
+ "stop": OAChatCompletionRequest.model_fields["stop"].default,
577
+ "response_format": OAChatCompletionRequest.model_fields[
578
+ "response_format"
579
+ ].default,
580
+ # stop: None, seed: None, ignore_eos: False, min_tokens: 0
581
+ }
582
+
583
+ for oa_key, vllm_key in key_openai_to_vllm_map.items():
584
+ if vllm_key in valid_sampling_keys:
585
+ value = raw_request_params.get(
586
+ oa_key, param_defaults_for_sig.get(oa_key)
587
+ ) # Use default if not in request
588
+ # For 'stop', ensure it's a tuple if it's a list for hashability
589
+ if vllm_key == "stop" and isinstance(value, list):
590
+ value = tuple(value)
591
+ if vllm_key == "response_format" and isinstance(value, dict):
592
+ value = json.dumps(value, sort_keys=True)
593
+ if (
594
+ value is not None
595
+ ): # Only add if value is meaningfully set or defaulted for signature
596
+ sig_sampling_items.append((vllm_key, value))
597
+
598
+ # Sort for stable signature
599
+ sig_sampling_items.sort(key=lambda item: item[0])
600
+
601
+ filtered_extra_body_items = []
602
+ if extra_body:
603
+ for k, v in sorted(extra_body.items()):
604
+ # Only include extra_body items that are NOT already part of vLLM SamplingParams
605
+ # to avoid them influencing signature if they are just alternative ways to pass standard params.
606
+ # This primarily targets things like 'guided_decoding_regex'.
607
+ if k not in valid_sampling_keys:
608
+ filtered_extra_body_items.append((k, v))
609
+
610
+ return PoolSignature(
611
+ model_name=model_name,
612
+ request_type=request_type,
613
+ sampling_params_tuple=tuple(sig_sampling_items),
614
+ extra_body_params_tuple=tuple(filtered_extra_body_items),
615
+ )
616
+
617
+
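To illustrate the pooling key (an illustrative example, not from the package's tests): two chat requests that differ only in `max_tokens` produce equal signatures and can be batched together, while a different `temperature` lands in a separate pool.

```python
# Illustrative only: signatures deliberately exclude max_tokens and stream.
sig_a = create_pool_signature("my-model", "chat", {"temperature": 0.2, "max_tokens": 64}, None)
sig_b = create_pool_signature("my-model", "chat", {"temperature": 0.2, "max_tokens": 512}, None)
assert sig_a == sig_b

sig_c = create_pool_signature("my-model", "chat", {"temperature": 0.9}, None)
assert sig_c != sig_a
```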
618
+ def llm_worker(
619
+ script_args: ScriptArguments,
620
+ data_parallel_rank: int,
621
+ master_port: int,
622
+ connection: MPConnection,
623
+ ) -> None:
624
+ # Set required environment variables for DP to work with vLLM
625
+ os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
626
+ os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
627
+ os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
628
+ os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
629
+
630
+ llm = LLM(
631
+ model=script_args.model,
632
+ revision=script_args.revision,
633
+ tensor_parallel_size=script_args.tensor_parallel_size,
634
+ gpu_memory_utilization=script_args.gpu_memory_utilization,
635
+ enforce_eager=script_args.enforce_eager,
636
+ dtype=script_args.dtype,
637
+ enable_prefix_caching=script_args.enable_prefix_caching,
638
+ max_model_len=script_args.max_model_len,
639
+ worker_extension_cls="arbor.server.services.inference.vllm_serve.WeightSyncWorkerExtension",
640
+ )
641
+
642
+ # Send ready signal to parent process
643
+ connection.send({"status": "ready"})
644
+
645
+ while True:
646
+ # Wait for commands from the parent process
647
+ try:
648
+ command = connection.recv()
649
+ except KeyboardInterrupt:
650
+ llm.collective_rpc(method="close_communicator")
651
+ break
652
+
653
+ # Handle commands
654
+ if command["type"] in ["call", "fire_and_forget"]:
655
+ method_name = command["method"]
656
+ args, kwargs = command.get("args", ()), command.get("kwargs", {})
657
+
658
+ # Add debugging
659
+ logger.debug(
660
+ f"[WORKER {data_parallel_rank}] Received command: {method_name}"
661
+ )
662
+
663
+ try:
664
+ method = getattr(llm, method_name)
665
+ logger.debug(
666
+ f"[WORKER {data_parallel_rank}] Calling {method_name} with kwargs keys: {list(kwargs.keys()) if kwargs else 'none'}"
667
+ )
668
+
669
+ # Call the method
670
+ result = method(*args, **kwargs)
671
+
672
+ logger.debug(
673
+ f"[WORKER {data_parallel_rank}] {method_name} completed, result type: {type(result)}"
674
+ )
675
+
676
+ if command["type"] == "call":
677
+ # Send result back
678
+ logger.debug(f"[WORKER {data_parallel_rank}] Sending result back")
679
+ connection.send(result)
680
+ logger.debug(f"[WORKER {data_parallel_rank}] Result sent")
681
+ except Exception as e:
682
+ logger.error(
683
+ f"[WORKER {data_parallel_rank}] Error in {method_name}: {e}",
684
+ exc_info=True,
685
+ )
686
+ if command["type"] == "call":
687
+ # Send error back as a special result
688
+ connection.send(
689
+ {"error": str(e), "traceback": traceback.format_exc()}
690
+ )
691
+ elif command["type"] == "shutdown":
692
+ logger.info(f"[WORKER {data_parallel_rank}] Received shutdown command")
693
+ break
694
+
695
+
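The parent side of this loop (the FastAPI lifespan further down the file) presumably spawns one `llm_worker` per data-parallel rank over a `multiprocessing.Pipe` and waits for each ready signal; a rough sketch under that assumption:

```python
# Rough sketch of the spawning side (assumption: the real lifespan code later in
# this file may differ, e.g. in port handling or error recovery).
def spawn_llm_workers(script_args: ScriptArguments):
    master_port = get_open_port()  # imported above from vllm.utils
    connections, processes = [], []
    for rank in range(script_args.data_parallel_size):
        parent_conn, child_conn = Pipe()  # Pipe/Process imported above
        proc = Process(
            target=llm_worker,
            args=(script_args, rank, master_port, child_conn),
        )
        proc.start()
        connections.append(parent_conn)
        processes.append(proc)
    # Each worker sends {"status": "ready"} once its LLM engine is constructed.
    for conn in connections:
        assert conn.recv() == {"status": "ready"}
    return connections, processes
```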
696
+ def chunk_list(lst: list, n: int) -> list[list]:
697
+ """
698
+ Split list `lst` into `n` evenly distributed sublists.
699
+
700
+ Example:
701
+ >>> chunk_list([1, 2, 3, 4, 5, 6], 2)
702
+ [[1, 2, 3], [4, 5, 6]]
703
+ >>> chunk_list([1, 2, 3, 4, 5, 6], 4)
704
+ [[1, 2], [3, 4], [5], [6]]
705
+ >>> chunk_list([1, 2, 3, 4, 5, 6], 8)
706
+ [[1], [2], [3], [4], [5], [6], [], []]
707
+ """
708
+ k, r = divmod(len(lst), n)
709
+ return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]
710
+
711
+
712
+ async def batch_processing_loop(
713
+ script_args: ScriptArguments,
714
+ connections: list[AnyType],
715
+ queue: asyncio.Queue, # This queue now receives PooledRequestState
716
+ logger_instance: logging.Logger,
717
+ ):
718
+ global pending_requests_by_signature, active_pool_signature, active_pool_requests
719
+ global active_generation_count, generation_count_lock
720
+
721
+ if not connections:
722
+ logger_instance.error(
723
+ "Batch Processor: No LLM workers available. Shutting down loop."
724
+ )
725
+ # We cannot process anything if connections are not there from the start.
726
+ # New requests added to queue will eventually timeout or error when picked by lifespan shutdown.
727
+ return
728
+
729
+ while True:
730
+ try:
731
+ # 1. Ingest new requests (non-blocking or short timeout)
732
+ # This part tries to fill pending_requests_by_signature from the main request_queue
733
+ try:
734
+ while True: # Loop to drain current items in asyncio.Queue
735
+ pooled_req_state: PooledRequestState = queue.get_nowait()
736
+ pending_requests_by_signature[
737
+ pooled_req_state.pool_signature
738
+ ].append(pooled_req_state)
739
+ logger_instance.debug(
740
+ f"[LIFECYCLE] Request {pooled_req_state.request_id} enqueued to pending pool. "
741
+ f"max_tokens={pooled_req_state.effective_max_tokens}, "
742
+ f"signature={pooled_req_state.pool_signature}"
743
+ )
744
+ queue.task_done() # Signal that this item from main queue is taken
745
+ except asyncio.QueueEmpty:
746
+ pass # No new requests in the main queue right now
747
+
748
+ # 2. Activate a new pool if current is empty and pending requests exist
749
+ if not active_pool_requests and pending_requests_by_signature:
750
+ # Simple strategy: pick the first signature that has pending requests
751
+ # More sophisticated strategies (e.g. largest batch) could be implemented here
752
+ # items() gives a view, convert to list to pop
753
+ available_signatures = list(pending_requests_by_signature.keys())
754
+ if available_signatures:
755
+ active_pool_signature = available_signatures[0] # Pick one
756
+ active_pool_requests = pending_requests_by_signature.pop(
757
+ active_pool_signature
758
+ )
759
+ logger_instance.info(
760
+ f"[LIFECYCLE] Activated new pool with signature: {active_pool_signature}. Size: {len(active_pool_requests)}"
761
+ )
762
+ for req in active_pool_requests:
763
+ logger_instance.debug(
764
+ f"[LIFECYCLE] Request {req.request_id} moved to active pool"
765
+ )
766
+
767
+ # 3. Merge new matching requests into the active pool
768
+ if (
769
+ active_pool_signature
770
+ and active_pool_signature in pending_requests_by_signature
771
+ ):
772
+ newly_matching_requests = pending_requests_by_signature.pop(
773
+ active_pool_signature
774
+ )
775
+ active_pool_requests.extend(newly_matching_requests)
776
+ logger_instance.info(
777
+ f"Merged {len(newly_matching_requests)} requests into active pool. New size: {len(active_pool_requests)}"
778
+ )
779
+
780
+ # 4. Process active pool chunk
781
+ if active_pool_requests:
782
+ # Take a sub-batch from the active pool
783
+ # If active_pool_requests is not empty, active_pool_signature must be set.
784
+ assert (
785
+ active_pool_signature is not None
786
+ ), "active_pool_signature cannot be None if active_pool_requests is populated"
787
+
788
+ # Filter out already-completed requests before selecting sub-batch
789
+ available_requests = [
790
+ req for req in active_pool_requests if not req.is_complete
791
+ ]
792
+
793
+ # Log the state of requests in the pool
794
+ logger_instance.debug(
795
+ f"Active pool has {len(active_pool_requests)} total requests, {len(available_requests)} available for processing"
796
+ )
797
+ for req in active_pool_requests:
798
+ logger_instance.debug(
799
+ f" Request {req.request_id}: tokens={req.generated_token_count}/{req.effective_max_tokens}, "
800
+ f"is_complete={req.is_complete}, finish_reason={req.finish_reason}"
801
+ )
802
+
803
+ if not available_requests:
804
+ # All requests in active pool are already completed, clear the pool
805
+ logger_instance.info(
806
+ f"All requests in active pool {active_pool_signature} are already completed. Clearing pool."
807
+ )
808
+ active_pool_requests.clear()
809
+ active_pool_signature = None
810
+ continue
811
+
812
+ sub_batch_to_process: list[PooledRequestState] = []
813
+ sub_batch_size = min(
814
+ len(available_requests), script_args.max_batch_size
815
+ )
816
+ sub_batch_to_process = available_requests[:sub_batch_size]
817
+
818
+ logger_instance.debug(
819
+ f"[BATCH_PROCESSOR] Processing sub-batch of {len(sub_batch_to_process)} for sig: {active_pool_signature}"
820
+ )
821
+
822
+ # Track active generation
823
+ async with generation_count_lock:
824
+ active_generation_count += 1
825
+ logger_instance.debug(
826
+ f"[BATCH_PROCESSOR] Active generation count increased to {active_generation_count}"
827
+ )
828
+
829
+ try:
830
+ # Prepare inputs for LLM
831
+ # All requests in sub_batch_to_process share active_pool_signature
832
+ # So, sampling params (except max_tokens) are the same.
833
+
834
+ # Construct SamplingParams from the active_pool_signature
835
+ # The signature stores param tuples. Convert back to dict for SamplingParams.
836
+ sig_sampling_dict = dict(
837
+ active_pool_signature.sampling_params_tuple
838
+ )
839
+ sig_extra_body_dict = dict(
840
+ active_pool_signature.extra_body_params_tuple
841
+ )
842
+
843
+ # Override 'n' to 1 for chunked generation as per design.
844
+ # Log if original 'n' was different.
845
+ original_n = sig_sampling_dict.get("n", 1)
846
+ if original_n != 1:
847
+ logger_instance.warning(
848
+ f"Pool {active_pool_signature}: Original 'n={original_n}' overridden to n=1 for chunked generation."
849
+ )
850
+
851
+ # Calculate the minimum tokens remaining across all requests in the batch
852
+ min_tokens_remaining = min(
853
+ req.effective_max_tokens - req.generated_token_count
854
+ for req in sub_batch_to_process
855
+ )
856
+
857
+ # Log token calculations
858
+ logger_instance.debug(
859
+ f"[CHUNK_CALC] Calculating chunk size for {len(sub_batch_to_process)} requests:"
860
+ )
861
+ for req in sub_batch_to_process:
862
+ tokens_left = (
863
+ req.effective_max_tokens - req.generated_token_count
864
+ )
865
+ logger_instance.debug(
866
+ f"[CHUNK_CALC] Request {req.request_id}: {req.generated_token_count}/{req.effective_max_tokens} tokens, {tokens_left} remaining"
867
+ )
868
+
869
+ # Limit chunk size to available room
870
+ chunk_size = min(script_args.token_chunk_size, min_tokens_remaining)
871
+ logger_instance.debug(
872
+ f"[CHUNK_CALC] Final chunk size: {chunk_size} (configured: {script_args.token_chunk_size}, min_remaining: {min_tokens_remaining})"
873
+ )
874
+
875
+ # CRITICAL: Ensure chunk_size is at least 1 to avoid vLLM issues
876
+ if chunk_size <= 0:
877
+ logger_instance.error(
878
+ f"Invalid chunk size {chunk_size} calculated. Min remaining: {min_tokens_remaining}"
879
+ )
880
+ # Mark all requests as complete if we can't generate any more tokens
881
+ for req_state in sub_batch_to_process:
882
+ if not req_state.is_complete:
883
+ req_state.finish_reason = "length"
884
+ logger_instance.info(
885
+ f"Request {req_state.request_id} marked complete due to no room for generation"
886
+ )
887
+ continue # Skip this iteration
888
+
889
+ logger_instance.debug(
890
+ f"Chunk size for batch: {chunk_size} (min remaining: {min_tokens_remaining}, configured: {script_args.token_chunk_size})"
891
+ )
892
+
893
+ # Create a new dict for **kwargs, excluding 'n' as it's set explicitly
894
+ kwargs_for_sampling_params = {
895
+ k: v for k, v in sig_sampling_dict.items() if k != "n"
896
+ }
897
+
898
+ guided_decoding = None
899
+ if "guided_decoding_regex" in sig_extra_body_dict:
900
+ guided_decoding = GuidedDecodingParams(
901
+ backend="outlines",
902
+ regex=sig_extra_body_dict["guided_decoding_regex"],
903
+ )
904
+ elif "response_format" in kwargs_for_sampling_params:
905
+ response_format_json = json.loads(
906
+ kwargs_for_sampling_params["response_format"]
907
+ )
908
+ response_format_type = response_format_json["type"]
909
+ logger_instance.info(
910
+ f"Response format param: {response_format_type}"
911
+ )
912
+ if response_format_type == "json_schema":
913
+ json_schema = response_format_json["json_schema"]
914
+ guided_decoding = GuidedDecodingParams(json=json_schema)
915
+ logger_instance.info(
916
+ f"Going with json_schema {json_schema}"
917
+ )
918
+ elif response_format_type == "json_object":
919
+                         logger_instance.info("Going with json_object")
920
+ guided_decoding = GuidedDecodingParams(json_object=True)
921
+ else:
922
+ logger_instance.info(
923
+ f"Response format param provided that isn't supported: {response_format_type}"
924
+ )
925
+ # remove the response_format key because we can't pass it to sampling params
926
+ del kwargs_for_sampling_params["response_format"]
927
+ else:
928
+ # guided_decoding should be none if we don't want to use any guided decoding
929
+ pass
930
+
931
+ vllm_sampling_params = SamplingParams(
932
+ **kwargs_for_sampling_params,
933
+ n=1, # Generate one sequence continuation per request in the chunk
934
+ max_tokens=chunk_size, # Use calculated chunk size
935
+ # Ensure guided_decoding is correctly set up if present in extra_body
936
+ guided_decoding=guided_decoding,
937
+ # Remove any params from extra_body that might also be in SamplingParams if they were not filtered by create_pool_signature
938
+ **{
939
+ k: v
940
+ for k, v in sig_extra_body_dict.items()
941
+ if k in _get_sampling_param_names()
942
+ and k != "guided_decoding_regex"
943
+ },
944
+ )
945
+
946
+ # --- Bucket chat requests by first chunk vs continuing ---
947
+ first_chunk_inputs = []
948
+ first_chunk_states = []
949
+ continue_chunk_states = []
950
+ prompts_for_vllm = []
951
+ is_chat_pool = active_pool_signature.request_type == "chat"
952
+
953
+ for req_state in sub_batch_to_process:
954
+ if is_chat_pool:
955
+ current_messages = []
956
+ if req_state.original_chat_messages:
957
+ current_messages.extend(
958
+ [
959
+ m.model_dump()
960
+ for m in req_state.original_chat_messages
961
+ ]
962
+ )
963
+
964
+ # For continuing generation, we need to ensure there's an assistant message to continue
965
+ if req_state.generated_token_count == 0:
966
+ # First chunk - ensure we have a valid message sequence ending with user
967
+ if not current_messages:
968
+ logger_instance.error(
969
+ f"Request {req_state.request_id} has no messages"
970
+ )
971
+ req_state.error = ValueError("No messages provided")
972
+ continue
973
+ if current_messages[-1]["role"] != "user":
974
+ logger_instance.error(
975
+ f"Request {req_state.request_id} last message is not from user for first chunk"
976
+ )
977
+ req_state.error = ValueError(
978
+ "Last message must be from user for first chunk"
979
+ )
980
+ continue
981
+ first_chunk_inputs.append(current_messages)
982
+ first_chunk_states.append(req_state)
983
+ else:
984
+ # Continuing chunk - add accumulated content as assistant message
985
+ if req_state.accumulated_content:
986
+ current_messages.append(
987
+ {
988
+ "role": "assistant",
989
+ "content": req_state.accumulated_content,
990
+ }
991
+ )
992
+ else:
993
+ # This should not happen - we should have content if we're continuing
994
+ logger_instance.error(
995
+ f"Request {req_state.request_id} has no accumulated content for continuation"
996
+ )
997
+ req_state.error = ValueError(
998
+ "No content to continue"
999
+ )
1000
+ continue
1001
+ continue_chunk_states.append(req_state)
1002
+ else:
1003
+ if isinstance(req_state.original_prompt, str):
1004
+ current_prompt = req_state.original_prompt
1005
+ elif isinstance(req_state.original_prompt, list):
1006
+ current_prompt = (
1007
+ req_state.original_prompt[0]
1008
+ if req_state.original_prompt
1009
+ else ""
1010
+ )
1011
+ else:
1012
+ current_prompt = str(req_state.original_prompt or "")
1013
+ prompts_for_vllm.append(
1014
+ current_prompt + req_state.accumulated_content
1015
+ )
1016
+
1017
+ # Only process first-chunk chat requests in this tick, then continuing if no first-chunk left
1018
+ llm_results = []
1019
+ processed_states = []
1020
+ if is_chat_pool:
1021
+ loop = asyncio.get_running_loop()
1022
+ if first_chunk_inputs:
1023
+ # Filter out any already-completed requests from first_chunk_states
1024
+ active_first_states = []
1025
+ active_first_inputs = []
1026
+ for i, req_state in enumerate(first_chunk_states):
1027
+ if req_state.is_complete:
1028
+ logger_instance.debug(
1029
+ f"Skipping already-completed request {req_state.request_id} in first chunk processing"
1030
+ )
1031
+ continue
1032
+ active_first_states.append(req_state)
1033
+ active_first_inputs.append(first_chunk_inputs[i])
1034
+
1035
+ if not active_first_states:
1036
+ logger_instance.debug(
1037
+ "All first chunk requests are already completed, skipping LLM call"
1038
+ )
1039
+ processed_states = []
1040
+ llm_results = []
1041
+ else:
1042
+ flags = dict(
1043
+ add_generation_prompt=True,
1044
+ continue_final_message=False,
1045
+ )
1046
+ payload = {
1047
+ "type": "call",
1048
+ "method": "chat",
1049
+ "kwargs": {
1050
+ "messages": active_first_inputs,
1051
+ "sampling_params": vllm_sampling_params,
1052
+ **flags,
1053
+ },
1054
+ }
1055
+ logger_instance.debug(
1056
+ f"Sending first-chunk chat request to LLM with {len(active_first_inputs)} messages"
1057
+ )
1058
+
1059
+ worker_idx = -1 # Initialize to avoid unbound variable
1060
+ try:
1061
+ worker_idx, worker_conn = (
1062
+ await get_next_worker_connection(connections)
1063
+ )
1064
+ logger_instance.debug(
1065
+ f"Using worker {worker_idx} for first-chunk chat request"
1066
+ )
1067
+ llm_results = await async_send_and_recv(
1068
+ worker_conn, payload, timeout=60.0
1069
+ )
1070
+ logger_instance.debug(
1071
+ f"Received {len(llm_results)} results from LLM for first-chunk chat"
1072
+ )
1073
+ except asyncio.TimeoutError:
1074
+ logger_instance.error(
1075
+ f"Worker {worker_idx} timeout for first-chunk chat after 60s"
1076
+ )
1077
+ for req_state in active_first_states:
1078
+ req_state.error = TimeoutError(
1079
+ "Worker timeout during generation"
1080
+ )
1081
+ llm_results = []
1082
+ except Exception as e:
1083
+ logger_instance.error(
1084
+ f"Error calling LLM for first-chunk chat: {e}",
1085
+ exc_info=True,
1086
+ )
1087
+ for req_state in active_first_states:
1088
+ req_state.error = e
1089
+ llm_results = []
1090
+ processed_states = active_first_states
1091
+ elif continue_chunk_states:
1092
+ # No first-chunk requests, process continuing requests
1093
+ continue_chunk_inputs = []
1094
+ # Filter out any already-completed requests
1095
+ active_continue_states = []
1096
+ for req_state in continue_chunk_states:
1097
+ if req_state.is_complete:
1098
+ logger_instance.debug(
1099
+ f"Skipping already-completed request {req_state.request_id} in continue chunk processing"
1100
+ )
1101
+ continue
1102
+ active_continue_states.append(req_state)
1103
+ current_messages = []
1104
+ if req_state.original_chat_messages:
1105
+ current_messages.extend(
1106
+ [
1107
+ m.model_dump()
1108
+ for m in req_state.original_chat_messages
1109
+ ]
1110
+ )
1111
+
1112
+ # Must have accumulated content to continue
1113
+ if not req_state.accumulated_content:
1114
+ logger_instance.error(
1115
+ f"Request {req_state.request_id} has no accumulated content for continuation"
1116
+ )
1117
+ req_state.error = ValueError(
1118
+ "No content to continue generation"
1119
+ )
1120
+ active_continue_states.remove(
1121
+ req_state
1122
+ ) # Remove from active list
1123
+ continue
1124
+
1125
+ # Add the accumulated content as the assistant message to continue
1126
+ current_messages.append(
1127
+ {
1128
+ "role": "assistant",
1129
+ "content": req_state.accumulated_content,
1130
+ }
1131
+ )
1132
+ continue_chunk_inputs.append(current_messages)
1133
+
1134
+ if not active_continue_states:
1135
+ logger_instance.debug(
1136
+ "All continue chunk requests are already completed, skipping LLM call"
1137
+ )
1138
+ processed_states = []
1139
+ llm_results = []
1140
+ else:
1141
+ flags = dict(
1142
+ add_generation_prompt=False,
1143
+ continue_final_message=True,
1144
+ )
1145
+ payload = {
1146
+ "type": "call",
1147
+ "method": "chat",
1148
+ "kwargs": {
1149
+ "messages": continue_chunk_inputs,
1150
+ "sampling_params": vllm_sampling_params,
1151
+ **flags,
1152
+ },
1153
+ }
1154
+ logger_instance.debug(
1155
+ f"Sending continue-chunk chat request to LLM with {len(continue_chunk_inputs)} messages"
1156
+ )
1157
+
1158
+ worker_idx = -1 # Initialize to avoid unbound variable
1159
+ try:
1160
+ worker_idx, worker_conn = (
1161
+ await get_next_worker_connection(connections)
1162
+ )
1163
+ logger_instance.debug(
1164
+ f"Using worker {worker_idx} for continue-chunk chat request"
1165
+ )
1166
+ llm_results = await async_send_and_recv(
1167
+ worker_conn, payload, timeout=60.0
1168
+ )
1169
+ logger_instance.debug(
1170
+ f"Received {len(llm_results)} results from LLM for continue-chunk chat"
1171
+ )
1172
+ except asyncio.TimeoutError:
1173
+ logger_instance.error(
1174
+ f"Worker {worker_idx} timeout for continue-chunk chat after 60s"
1175
+ )
1176
+ for req_state in active_continue_states:
1177
+ req_state.error = TimeoutError(
1178
+ "Worker timeout during generation"
1179
+ )
1180
+ llm_results = []
1181
+ except Exception as e:
1182
+ logger_instance.error(
1183
+ f"Error calling LLM for continue-chunk chat: {e}",
1184
+ exc_info=True,
1185
+ )
1186
+ for req_state in active_continue_states:
1187
+ req_state.error = e
1188
+ llm_results = []
1189
+ processed_states = active_continue_states
1190
+ else:
1191
+ # No requests to process in this iteration
1192
+ logger_instance.debug(
1193
+ "No chat requests to process in this iteration"
1194
+ )
1195
+ processed_states = []
1196
+ llm_results = []
1197
+ else:
1198
+ # completion – unchanged
1199
+ loop = asyncio.get_running_loop()
1200
+ payload = {
1201
+ "type": "call",
1202
+ "method": "generate",
1203
+ "kwargs": {
1204
+ "prompts": prompts_for_vllm,
1205
+ "sampling_params": vllm_sampling_params,
1206
+ },
1207
+ }
1208
+ logger_instance.debug(
1209
+ f"Sending completion request to LLM with {len(prompts_for_vllm)} prompts"
1210
+ )
1211
+ worker_idx = -1 # Initialize to avoid unbound variable
1212
+ try:
1213
+ worker_idx, worker_conn = await get_next_worker_connection(
1214
+ connections
1215
+ )
1216
+ logger_instance.debug(
1217
+ f"Using worker {worker_idx} for completion request"
1218
+ )
1219
+ llm_results = await async_send_and_recv(
1220
+ worker_conn, payload, timeout=60.0
1221
+ )
1222
+ logger_instance.debug(
1223
+ f"Received {len(llm_results)} results from LLM for completion"
1224
+ )
1225
+ except asyncio.TimeoutError:
1226
+ logger_instance.error(
1227
+ f"Worker {worker_idx} timeout for completion after 60s"
1228
+ )
1229
+ for req_state in sub_batch_to_process:
1230
+ req_state.error = TimeoutError(
1231
+ "Worker timeout during generation"
1232
+ )
1233
+ llm_results = []
1234
+ except Exception as e:
1235
+ logger_instance.error(
1236
+ f"Error calling LLM for completion: {e}", exc_info=True
1237
+ )
1238
+ for req_state in sub_batch_to_process:
1239
+ req_state.error = e
1240
+ llm_results = []
1241
+ processed_states = sub_batch_to_process
1242
+
1243
+ # Now, update state for each request in the processed_states
1244
+ temp_failed_requests_in_sub_batch: list[PooledRequestState] = []
1245
+
1246
+ if is_chat_pool:
1247
+ if processed_states and (
1248
+ len(llm_results) != len(processed_states)
1249
+ ):
1250
+ logger_instance.error(
1251
+ f"LLM result count mismatch. Expected {len(processed_states)}, got {len(llm_results)} for sig {active_pool_signature}. Marking affected requests in sub-batch as error."
1252
+ )
1253
+ for req_state in processed_states:
1254
+ if not req_state.completed_and_signaled:
1255
+ req_state.error = RuntimeError(
1256
+ "LLM result mismatch in batch processing."
1257
+ )
1258
+ req_state.finish_reason = "error"
1259
+ temp_failed_requests_in_sub_batch.append(req_state)
1260
+ else:
1261
+ real_idx = 0
1262
+ for req_state in processed_states:
1263
+ if req_state.completed_and_signaled:
1264
+ continue
1265
+ request_output = llm_results[real_idx]
1266
+ real_idx += 1
1267
+ if (
1268
+ not request_output.outputs
1269
+ or len(request_output.outputs) == 0
1270
+ ):
1271
+ logger_instance.warning(
1272
+ f"Request {req_state.request_id} (idx {real_idx-1}) received no output from vLLM in chunk."
1273
+ )
1274
+ # This might happen if vLLM can't generate any tokens (e.g., due to constraints)
1275
+ # Mark as complete rather than error
1276
+ req_state.finish_reason = (
1277
+ "stop" # vLLM couldn't generate
1278
+ )
1279
+ logger_instance.info(
1280
+ f"Request {req_state.request_id} marked complete due to empty vLLM output"
1281
+ )
1282
+ continue
1283
+ completion_output = request_output.outputs[0]
1284
+ new_text_chunk = completion_output.text
1285
+ req_state.accumulated_content += new_text_chunk
1286
+ new_token_count = len(completion_output.token_ids)
1287
+ req_state.generated_token_count += new_token_count
1288
+
1289
+ # Store vLLM's finish reason but we'll interpret it carefully
1290
+ vllm_finish_reason = completion_output.finish_reason
1291
+ logger_instance.debug(
1292
+ f"[VLLM_RESPONSE] Request {req_state.request_id}: vLLM returned {new_token_count} tokens, finish_reason={vllm_finish_reason}"
1293
+ )
1294
+
1295
+ # Only update our finish_reason if it's meaningful
1296
+ if vllm_finish_reason == "length":
1297
+ # vLLM hit the chunk limit - only set our finish_reason if we're at our actual limit
1298
+ if (
1299
+ req_state.generated_token_count
1300
+ >= req_state.effective_max_tokens
1301
+ ):
1302
+ req_state.finish_reason = "length"
1303
+ logger_instance.debug(
1304
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='length' - hit actual limit"
1305
+ )
1306
+ else:
1307
+ # Don't set finish_reason - we can continue generating
1308
+ logger_instance.debug(
1309
+ f"[FINISH_REASON] Request {req_state.request_id}: Ignoring vLLM's finish_reason='length' - only at chunk limit"
1310
+ )
1311
+ elif vllm_finish_reason is not None:
1312
+ # Other finish reasons (stop, eos_token, etc.) are real completions
1313
+ req_state.finish_reason = vllm_finish_reason
1314
+ logger_instance.debug(
1315
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='{vllm_finish_reason}' from vLLM"
1316
+ )
1317
+
1318
+ # Log detailed state for debugging
1319
+ logger_instance.debug(
1320
+ f"Request {req_state.request_id} chunk result: "
1321
+ f"new_tokens={new_token_count}, total_tokens={req_state.generated_token_count}, "
1322
+ f"finish_reason={req_state.finish_reason}, chunk_text_len={len(new_text_chunk)}"
1323
+ )
1324
+
1325
+ # Check if generation has stopped
1326
+ if new_token_count < chunk_size:
1327
+ # Incomplete chunk indicates generation should stop
1328
+ logger_instance.info(
1329
+ f"Request {req_state.request_id} generated incomplete chunk. "
1330
+ f"Generated {new_token_count}/{chunk_size} tokens in chunk. "
1331
+ f"Marking as complete to prevent doom loop."
1332
+ )
1333
+ # Set finish reason if not already set by vLLM
1334
+ if req_state.finish_reason is None:
1335
+ req_state.finish_reason = (
1336
+ "stop" # Generation naturally stopped
1337
+ )
1338
+ logger_instance.debug(
1339
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='stop' due to incomplete chunk"
1340
+ )
1341
+
1342
+ # Log current state
1343
+ logger_instance.debug(
1344
+ f"Request {req_state.request_id} chunk processed. Tokens in chunk: {new_token_count}, total: {req_state.generated_token_count}, is_complete: {req_state.is_complete}"
1345
+ )
1346
+ else:
1347
+ # completion – unchanged
1348
+ if len(llm_results) != len(processed_states):
1349
+ logger_instance.error(
1350
+ f"LLM result count mismatch. Expected {len(processed_states)}, got {len(llm_results)} for sig {active_pool_signature}. Marking affected requests in sub-batch as error."
1351
+ )
1352
+ for req_state in processed_states:
1353
+ if not req_state.completed_and_signaled:
1354
+ req_state.error = RuntimeError(
1355
+ "LLM result mismatch in batch processing."
1356
+ )
1357
+ req_state.finish_reason = "error"
1358
+ temp_failed_requests_in_sub_batch.append(req_state)
1359
+ else:
1360
+ real_idx = 0
1361
+ for req_state in processed_states:
1362
+ if req_state.completed_and_signaled:
1363
+ continue
1364
+ request_output = llm_results[real_idx]
1365
+ real_idx += 1
1366
+ if (
1367
+ not request_output.outputs
1368
+ or len(request_output.outputs) == 0
1369
+ ):
1370
+ logger_instance.warning(
1371
+ f"Request {req_state.request_id} (idx {real_idx-1}) received no output from vLLM in chunk."
1372
+ )
1373
+ # This might happen if vLLM can't generate any tokens (e.g., due to constraints)
1374
+ # Mark as complete rather than error
1375
+ req_state.finish_reason = (
1376
+ "stop" # vLLM couldn't generate
1377
+ )
1378
+ logger_instance.info(
1379
+ f"Request {req_state.request_id} marked complete due to empty vLLM output"
1380
+ )
1381
+ continue
1382
+ completion_output = request_output.outputs[0]
1383
+ new_text_chunk = completion_output.text
1384
+ req_state.accumulated_content += new_text_chunk
1385
+ new_token_count = len(completion_output.token_ids)
1386
+ req_state.generated_token_count += new_token_count
1387
+
1388
+ # Store vLLM's finish reason but we'll interpret it carefully
1389
+ vllm_finish_reason = completion_output.finish_reason
1390
+ logger_instance.debug(
1391
+ f"[VLLM_RESPONSE] Request {req_state.request_id}: vLLM returned {new_token_count} tokens, finish_reason={vllm_finish_reason}"
1392
+ )
1393
+
1394
+ # Only update our finish_reason if it's meaningful
1395
+ if vllm_finish_reason == "length":
1396
+ # vLLM hit the chunk limit - only set our finish_reason if we're at our actual limit
1397
+ if (
1398
+ req_state.generated_token_count
1399
+ >= req_state.effective_max_tokens
1400
+ ):
1401
+ req_state.finish_reason = "length"
1402
+ logger_instance.debug(
1403
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='length' - hit actual limit"
1404
+ )
1405
+ else:
1406
+ # Don't set finish_reason - we can continue generating
1407
+ logger_instance.debug(
1408
+ f"[FINISH_REASON] Request {req_state.request_id}: Ignoring vLLM's finish_reason='length' - only at chunk limit"
1409
+ )
1410
+ elif vllm_finish_reason is not None:
1411
+ # Other finish reasons (stop, eos_token, etc.) are real completions
1412
+ req_state.finish_reason = vllm_finish_reason
1413
+ logger_instance.debug(
1414
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='{vllm_finish_reason}' from vLLM"
1415
+ )
1416
+
1417
+ # Log detailed state for debugging
1418
+ logger_instance.debug(
1419
+ f"Request {req_state.request_id} chunk result: "
1420
+ f"new_tokens={new_token_count}, total_tokens={req_state.generated_token_count}, "
1421
+ f"finish_reason={req_state.finish_reason}, chunk_text_len={len(new_text_chunk)}"
1422
+ )
1423
+
1424
+ # Check if generation has stopped
1425
+ if new_token_count < chunk_size:
1426
+ # Incomplete chunk indicates generation should stop
1427
+ logger_instance.info(
1428
+ f"Request {req_state.request_id} generated incomplete chunk. "
1429
+ f"Generated {new_token_count}/{chunk_size} tokens in chunk. "
1430
+ f"Marking as complete to prevent doom loop."
1431
+ )
1432
+ # Set finish reason if not already set by vLLM
1433
+ if req_state.finish_reason is None:
1434
+ req_state.finish_reason = (
1435
+ "stop" # Generation naturally stopped
1436
+ )
1437
+ logger_instance.debug(
1438
+ f"[FINISH_REASON] Request {req_state.request_id}: Setting finish_reason='stop' due to incomplete chunk"
1439
+ )
1440
+
1441
+ # Log current state
1442
+ logger_instance.debug(
1443
+ f"Request {req_state.request_id} chunk processed. Tokens in chunk: {new_token_count}, total: {req_state.generated_token_count}, is_complete: {req_state.is_complete}"
1444
+ )
1445
+
1446
+ # Now, handle all finished or errored requests from this sub_batch
1447
+ # These need to be removed from active_pool_requests and have their events set.
1448
+ completed_in_sub_batch: list[PooledRequestState] = []
1449
+ remaining_in_sub_batch: list[PooledRequestState] = []
1450
+
1451
+ # Iterate over sub_batch_to_process to decide their fate
1452
+ updated_active_pool = []
1453
+
1454
+ # Create a set of request_ids from sub_batch_to_process for quick lookup
1455
+ sub_batch_ids = {r.request_id for r in sub_batch_to_process}
1456
+
1457
+ for req_state in active_pool_requests: # Iterate main active pool
1458
+ if req_state.request_id not in sub_batch_ids:
1459
+ updated_active_pool.append(
1460
+ req_state
1461
+ ) # Keep if not in current sub-batch
1462
+ continue
1463
+
1464
+ # req_state is from the current sub_batch. Check its status.
1465
+ logger_instance.debug(
1466
+ f"[COMPLETION_CHECK] Checking request {req_state.request_id}: "
1467
+ f"generated={req_state.generated_token_count}/{req_state.effective_max_tokens}, "
1468
+ f"finish_reason={req_state.finish_reason}, is_complete={req_state.is_complete}"
1469
+ )
1470
+
1471
+ if (
1472
+ req_state.is_complete
1473
+ and not req_state.completed_and_signaled
1474
+ ):
1475
+ # Request is complete but not yet signaled
1476
+ completed_in_sub_batch.append(req_state)
1477
+ logger_instance.info(
1478
+ f"[LIFECYCLE] Request {req_state.request_id} is complete and will be finalized. "
1479
+ f"Generated {req_state.generated_token_count} tokens, finish_reason={req_state.finish_reason}"
1480
+ )
1481
+ elif not req_state.is_complete:
1482
+ # Request is not complete, keep for next chunk
1483
+ updated_active_pool.append(req_state)
1484
+ logger_instance.debug(
1485
+ f"[LIFECYCLE] Request {req_state.request_id} is not complete, keeping for next chunk"
1486
+ )
1487
+ # If already signaled (completed_and_signaled is True), don't re-add or re-signal
1488
+
1489
+ active_pool_requests = updated_active_pool
1490
+ if not active_pool_requests:
1491
+ # Store signature before setting to None for logging
1492
+ deactivated_signature = active_pool_signature
1493
+ active_pool_signature = None # Deactivate pool if empty
1494
+ logger_instance.info(
1495
+ f"Deactivated pool {deactivated_signature} as it is now empty."
1496
+ )
1497
+
1498
+ for req_state in completed_in_sub_batch:
1499
+ if req_state.completed_and_signaled:
1500
+ continue # Already handled
1501
+
1502
+ response_content: (
1503
+ OAChatCompletionResponse
1504
+ | OACompletionResponse
1505
+ | JSONResponse
1506
+ )
1507
+ if req_state.error:
1508
+ logger_instance.error(
1509
+ f"Request {req_state.request_id} failed with error: {req_state.error}"
1510
+ )
1511
+ response_content = JSONResponse(
1512
+ status_code=500,
1513
+ content={
1514
+ "error": f"Processing error: {str(req_state.error)}",
1515
+ "request_id": req_state.request_id,
1516
+ },
1517
+ )
1518
+ elif req_state.request_type == "chat":
1519
+ final_choices = [
1520
+ OAChatChoice(
1521
+ index=0,
1522
+ message=OAChatMessage(
1523
+ role="assistant",
1524
+ content=req_state.accumulated_content,
1525
+ ),
1526
+ finish_reason=req_state.finish_reason,
1527
+ )
1528
+ ]
1529
+ response_content = OAChatCompletionResponse(
1530
+ id=f"chatcmpl-{uuid4().hex}", # Use original request_id if available? For now, new UUID.
1531
+ created=int(datetime.now(tz=timezone.utc).timestamp()),
1532
+ model=req_state.original_request.model,
1533
+ choices=final_choices,
1534
+ )
1535
+ else: # Completion
1536
+ final_choices = [
1537
+ OACompletionChoice(
1538
+ index=0,
1539
+ text=req_state.accumulated_content,
1540
+ finish_reason=req_state.finish_reason,
1541
+ )
1542
+ ]
1543
+ response_content = OACompletionResponse(
1544
+ id=f"cmpl-{uuid4().hex}",
1545
+ created=int(datetime.now(tz=timezone.utc).timestamp()),
1546
+ model=req_state.original_request.model,
1547
+ choices=final_choices,
1548
+ )
1549
+
1550
+ req_state.result_container[0] = response_content
1551
+ req_state.completion_event.set()
1552
+ req_state.completed_and_signaled = True
1553
+
1554
+ finally:
1555
+ # Always decrement active generation count
1556
+ async with generation_count_lock:
1557
+ active_generation_count -= 1
1558
+ logger_instance.debug(
1559
+ f"[BATCH_PROCESSOR] Active generation count decreased to {active_generation_count}"
1560
+ )
1561
+
1562
+ else: # No active pool
1563
+ await asyncio.sleep(
1564
+ 0.01
1565
+ ) # Small sleep if no active pool and pending queue was empty
1566
+
1567
+ except asyncio.CancelledError:
1568
+ logger_instance.info("Batch processing loop cancelled.")
1569
+ all_requests_to_cancel = list(active_pool_requests)
1570
+ active_pool_requests.clear()
1571
+ active_pool_signature = None
1572
+ for sig_list in pending_requests_by_signature.values():
1573
+ all_requests_to_cancel.extend(sig_list)
1574
+ pending_requests_by_signature.clear()
1575
+
1576
+ for req_state in all_requests_to_cancel:
1577
+ if not req_state.completed_and_signaled:
1578
+ req_state.result_container[0] = JSONResponse(
1579
+ status_code=503,
1580
+ content={"error": "Server shutting down, request cancelled."},
1581
+ )
1582
+ req_state.completion_event.set()
1583
+ req_state.completed_and_signaled = True
1584
+ break
1585
+ except Exception as e:
1586
+ logger_instance.error(
1587
+ f"Critical error in batch processing loop: {e}", exc_info=True
1588
+ )
1589
+ all_requests_to_fail = list(active_pool_requests)
1590
+ active_pool_requests.clear()
1591
+ active_pool_signature = None
1592
+ for sig_list in pending_requests_by_signature.values():
1593
+ all_requests_to_fail.extend(sig_list)
1594
+ pending_requests_by_signature.clear()
1595
+
1596
+ for req_state in all_requests_to_fail:
1597
+ if not req_state.completed_and_signaled:
1598
+ req_state.result_container[0] = JSONResponse(
1599
+ status_code=500,
1600
+ content={"error": f"Critical batch processor error: {str(e)}"},
1601
+ )
1602
+ req_state.completion_event.set()
1603
+ req_state.completed_and_signaled = True
1604
+ await asyncio.sleep(1) # Pause before retrying loop
1605
+
1606
+
1607
+ def main(script_args: ScriptArguments):
1608
+ global request_queue, batch_processor_task # Allow lifespan to assign to these
1609
+
1610
+ # Spawn dp workers, and setup pipes for communication
1611
+ master_port = get_open_port()
1612
+ connections: list[AnyType] = (
1613
+ []
1614
+ ) # Use Any type to avoid PipeConnection vs Connection mismatch
1615
+ processes = []
1616
+ for data_parallel_rank in range(script_args.data_parallel_size):
1617
+ # Use duplex=True to allow bidirectional communication
1618
+ # This is needed for "call" type commands that expect responses
1619
+ parent_connection, child_connection = Pipe(duplex=True)
1620
+ process = Process(
1621
+ target=llm_worker,
1622
+ args=(script_args, data_parallel_rank, master_port, child_connection),
1623
+ )
1624
+ process.start()
1625
+ connections.append(parent_connection)
1626
+ processes.append(process)
1627
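+ # The endpoint handlers below exchange two kinds of messages over these parent
+ # connections: {"type": "call", "method": ...}, which waits for a reply via
+ # connection.recv(), and {"type": "fire_and_forget", "method": "collective_rpc",
+ # "kwargs": {...}}, which expects no response (used for weight-sync notifications).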
+
1628
+ @asynccontextmanager
1629
+ async def lifespan(app: FastAPI):
1630
+ nonlocal processes # Capture from outer scope
1631
+ global request_queue, batch_processor_task # Defined at module level
1632
+
1633
+ logger.info(
1634
+ f"Lifespan: Waiting for {script_args.data_parallel_size} LLM worker(s) to be ready..."
1635
+ )
1636
+ ready_connections = set()
1637
+
1638
+ # Timeout for waiting for workers to get ready (e.g., 5 minutes)
1639
+ timeout_seconds = 300
1640
+ start_wait_time = time.time()
1641
+
1642
+ while len(ready_connections) < script_args.data_parallel_size:
1643
+ if time.time() - start_wait_time > timeout_seconds:
1644
+ logger.error(
1645
+ f"Lifespan: Timeout waiting for all LLM workers. Expected {script_args.data_parallel_size}, got {len(ready_connections)} ready."
1646
+ )
1647
+ raise RuntimeError("LLM workers failed to initialize in time")
1648
+
1649
+ for i, connection in enumerate(connections):
1650
+ if connection not in ready_connections:
1651
+ # Use poll() with a short timeout to avoid blocking indefinitely if a worker is stuck
1652
+ if connection.poll(
1653
+ 0.1
1654
+ ): # Check if data is available, with a 0.1s timeout
1655
+ try:
1656
+ msg = connection.recv()
1657
+ logger.info(
1658
+ f"Lifespan: Received message from worker {i}: {msg}"
1659
+ )
1660
+ if isinstance(msg, dict) and msg.get("status") == "ready":
1661
+ logger.info(f"Lifespan: LLM worker {i} reported ready.")
1662
+ ready_connections.add(connection)
1663
+ else:
1664
+ logger.warning(
1665
+ f"Lifespan: Received unexpected message from worker {i}: {msg}"
1666
+ )
1667
+ except Exception as e:
1668
+ logger.error(
1669
+ f"Lifespan: Error receiving message from worker {i}: {e}"
1670
+ )
1671
+
1672
+ if len(ready_connections) < script_args.data_parallel_size:
1673
+ time.sleep(
1674
+ 0.5
1675
+ ) # Brief sleep to avoid busy-waiting if not all workers are ready yet
1676
+
1677
+ if len(ready_connections) == script_args.data_parallel_size:
1678
+ logger.info(
1679
+ f"Lifespan: All {script_args.data_parallel_size} LLM worker(s) are ready. Proceeding to yield."
1680
+ )
1681
+ # Initialize request queue and start batch processor task
1682
+ request_queue = asyncio.Queue()
1683
+ logger.info(
1684
+ "Lifespan: Initialized request queue for batched chat completions."
1685
+ )
1686
+ batch_processor_task = asyncio.create_task(
1687
+ batch_processing_loop(script_args, connections, request_queue, logger)
1688
+ )
1689
+ logger.info("Lifespan: Started batch processing task for chat completions.")
1690
+ else:
1691
+ logger.error(
1692
+ f"Lifespan: Not all LLM workers became ready. Expected {script_args.data_parallel_size}, got {len(ready_connections)}. Uvicorn might not function correctly. Batch processor NOT started."
1693
+ )
1694
+
1695
+ yield
1696
+ logger.info("Lifespan: Uvicorn server is shutting down. Cleaning up resources.")
1697
+
1698
+ if batch_processor_task is not None and not batch_processor_task.done():
1699
+ logger.info("Lifespan: Cancelling batch processor task...")
1700
+ batch_processor_task.cancel()
1701
+ try:
1702
+ await batch_processor_task
1703
+ logger.info("Lifespan: Batch processor task finished.")
1704
+ except asyncio.CancelledError:
1705
+ logger.info("Lifespan: Batch processor task was cancelled as expected.")
1706
+ except Exception as e:
1707
+ logger.error(
1708
+ f"Lifespan: Exception during batch processor task shutdown: {e}",
1709
+ exc_info=True,
1710
+ )
1711
+
1712
+ # Wait for processes to terminate
1713
+ for process in processes:
1714
+ process.join(timeout=10) # Wait for 10 seconds for the process to terminate
1715
+ if process.is_alive():
1716
+ logger.warning(
1717
+ f"Process {process} is still alive after 10 seconds, attempting to terminate..."
1718
+ )
1719
+ process.terminate()
1720
+ process.join() # ensure process termination after calling terminate()
1721
+
1722
+ app = FastAPI(lifespan=lifespan)
1723
+
1724
+ # Define the endpoints for the model server
1725
+ @app.get("/health/")
1726
+ async def health():
1727
+ """
1728
+ Health check endpoint to verify that the server is running.
1729
+ """
1730
+ return {"status": "ok"}
1731
+
1732
+ @app.get("/get_world_size/")
1733
+ async def get_world_size():
1734
+ """
1735
+ Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.
1736
+
1737
+ Returns:
1738
+ `dict`:
1739
+ A dictionary containing the world size.
1740
+
1741
+ Example response:
1742
+ ```json
1743
+ {"world_size": 8}
1744
+ ```
1745
+ """
1746
+ return {
1747
+ "world_size": script_args.tensor_parallel_size
1748
+ * script_args.data_parallel_size
1749
+ }
1750
+
1751
+ # -------- OpenAI-chat ---------- #
1752
+ @app.post("/v1/chat/completions", response_model=OAChatCompletionResponse)
1753
+ async def openai_chat(req: OAChatCompletionRequest):
1754
+ global request_queue
1755
+
1756
+ if request_queue is None:
1757
+ logger.error("/v1/chat/completions: Request queue not initialized.")
1758
+ return JSONResponse(
1759
+ status_code=503,
1760
+ content={
1761
+ "error": "Server not ready, batch processing queue not initialized."
1762
+ },
1763
+ )
1764
+
1765
+ request_id = f"chatcmpl-{uuid4().hex}"
1766
+ logger.debug(f"Received chat request {request_id}, model: {req.model}")
1767
+
1768
+ # Create signature for pooling
1769
+ # The OAChatCompletionRequest fields are: model, messages, temperature, top_p, max_tokens, stream, extra_body
1770
+ # We need to pass the relevant ones to create_pool_signature
1771
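+ # Requests whose signature matches (same model, request type, sampling parameters
+ # and extra_body) are processed together as one sub-batch by the batch processing loop.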
+ raw_params_for_sig = {
1772
+ "temperature": req.temperature,
1773
+ "top_p": req.top_p,
1774
+ "presence_penalty": req.presence_penalty,
1775
+ "frequency_penalty": req.frequency_penalty,
1776
+ # "n" is not in OAChatCompletionRequest, defaults to 1 for chat in OpenAI spec
1777
+ "n": 1,
1778
+ "response_format": req.response_format,
1779
+ }
1780
+ # Other SamplingParams-mappable fields (e.g. 'stop') would be added here if they
1781
+ # were part of OAChatCompletionRequest; presence and frequency penalties are
1782
+ # already included above.
1783
+
1784
+ default_max = OAChatCompletionRequest.model_fields["max_tokens"].default
1785
+ effective_max_tokens = req.max_tokens or default_max
1786
+
1787
+ pool_sig = create_pool_signature(
1788
+ model_name=req.model,
1789
+ request_type="chat",
1790
+ raw_request_params=raw_params_for_sig,
1791
+ extra_body=req.extra_body,
1792
+ )
1793
+
1794
+ completion_event = asyncio.Event()
1795
+ result_container = [None]
1796
+
1797
+ pooled_state = PooledRequestState(
1798
+ original_request=req,
1799
+ completion_event=completion_event,
1800
+ result_container=result_container,
1801
+ request_id=request_id,
1802
+ request_type="chat",
1803
+ pool_signature=pool_sig,
1804
+ effective_max_tokens=effective_max_tokens,
1805
+ original_chat_messages=req.messages, # Store original messages
1806
+ )
1807
+
1808
+ logger.info(
1809
+ f"[LIFECYCLE] Created chat request {request_id}: max_tokens={effective_max_tokens}, "
1810
+ f"messages={len(req.messages)}, model={req.model}"
1811
+ )
1812
+
1813
+ try:
1814
+ await request_queue.put(pooled_state)
1815
+ logger.debug(f"[LIFECYCLE] Request {request_id} successfully queued")
1816
+ except Exception as e:
1817
+ logger.error(f"Enqueueing error for {request_id}: {e}", exc_info=True)
1818
+ return JSONResponse(
1819
+ status_code=500,
1820
+ content={"error": "Internal server error while queueing request."},
1821
+ )
1822
+
1823
+ try:
1824
+ await asyncio.wait_for(
1825
+ completion_event.wait(),
1826
+ timeout=script_args.batch_request_timeout_seconds,
1827
+ )
1828
+ except asyncio.TimeoutError:
1829
+ logger.error(f"Timeout for chat request {request_id} (model {req.model}).")
1830
+ pooled_state.timed_out = True
1831
+ pooled_state.completed_and_signaled = True
1832
+ pooled_state.completion_event.set()
1833
+ return JSONResponse(
1834
+ status_code=504, content={"error": "Request timed out."}
1835
+ )
1836
+ except Exception as e:
1837
+ logger.error(
1838
+ f"Error waiting for completion event for {request_id}: {e}",
1839
+ exc_info=True,
1840
+ )
1841
+ return JSONResponse(
1842
+ status_code=500,
1843
+ content={
1844
+ "error": "Internal server error while waiting for completion."
1845
+ },
1846
+ )
1847
+
1848
+ response_data = result_container[0]
1849
+ if (
1850
+ response_data is None
1851
+ ): # The processor should always set a result before signaling; None indicates an internal error.
1852
+ logger.error(
1853
+ f"No result for {request_id} (model {req.model}) after event set. Internal error."
1854
+ )
1855
+ return JSONResponse(
1856
+ status_code=500,
1857
+ content={"error": "Internal error: No result from processor."},
1858
+ )
1859
+
1860
+ if isinstance(response_data, JSONResponse):
1861
+ return response_data
1862
+
1863
+ if isinstance(response_data, OAChatCompletionResponse):
1864
+ # Must return JSONResponse for FastAPI
1865
+ return JSONResponse(response_data.model_dump())
1866
+ else:
1867
+ logger.error(
1868
+ f"Unexpected result type {type(response_data)} for {request_id}."
1869
+ )
1870
+ return JSONResponse(
1871
+ status_code=500,
1872
+ content={"error": "Internal error: Unexpected result format."},
1873
+ )
1874
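+ # Example /v1/chat/completions request (illustrative; host, port and model name are
+ # placeholders for whatever the server was started with):
+ #
+ #   curl http://localhost:8000/v1/chat/completions \
+ #     -H "Content-Type: application/json" \
+ #     -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}],
+ #         "temperature": 0.7, "max_tokens": 64}'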
+
1875
+ @app.get("/v1/models")
1876
+ async def list_models():
1877
+ return {
1878
+ "data": [{"id": script_args.model, "object": "model", "owned_by": "vllm"}]
1879
+ }
1880
+
1881
+ @app.post("/v1/completions", response_model=OACompletionResponse)
1882
+ async def openai_completions(req: OACompletionRequest):
1883
+ global request_queue
1884
+
1885
+ if request_queue is None:
1886
+ logger.error("/v1/completions: Request queue not initialized.")
1887
+ return JSONResponse(
1888
+ status_code=503,
1889
+ content={
1890
+ "error": "Server not ready, batch processing queue not initialized."
1891
+ },
1892
+ )
1893
+
1894
+ request_id = f"cmpl-{uuid4().hex}"
1895
+ logger.debug(f"Received completion request {request_id}, model: {req.model}")
1896
+
1897
+ # OACompletionRequest fields: model, prompt, temperature, top_p, max_tokens, n, extra_body
1898
+ raw_params_for_sig = {
1899
+ "temperature": req.temperature,
1900
+ "top_p": req.top_p,
1901
+ "presence_penalty": req.presence_penalty,
1902
+ "frequency_penalty": req.frequency_penalty,
1903
+ "n": req.n, # Pass 'n' from the request
1904
+ }
1905
+ # Besides the explicit fields above, copy any other OACompletionRequest fields that
1906
+ # map onto SamplingParams (e.g. a future 'stop' field) so they are picked up
1907
+ # automatically.
1908
+ # Enumerating the valid SamplingParams keys is safer than hard-coding the list of
1909
+ # request fields:
1910
+ valid_sp_keys = _get_sampling_param_names()
1911
+ for field_name, field_value in req.model_dump().items():
1912
+ if field_name in valid_sp_keys and field_name not in raw_params_for_sig:
1913
+ raw_params_for_sig[field_name] = field_value
1914
+
1915
+ default_max = OACompletionRequest.model_fields["max_tokens"].default
1916
+ effective_max_tokens = req.max_tokens or default_max
1917
+
1918
+ pool_sig = create_pool_signature(
1919
+ model_name=req.model,
1920
+ request_type="completion",
1921
+ raw_request_params=raw_params_for_sig,
1922
+ extra_body=req.extra_body,
1923
+ )
1924
+
1925
+ completion_event = asyncio.Event()
1926
+ result_container = [None]
1927
+
1928
+ # Check for list prompts for completion, which is problematic for current chunking.
1929
+ # vLLM's generate can take list of prompts, but our chunking logic (appending to prompt) assumes single string.
1930
+ if isinstance(req.prompt, list):
1931
+ if len(req.prompt) > 1:
1932
+ logger.warning(
1933
+ f"Request {request_id} for completion has a list of prompts. Only the first prompt will be used for chunked generation."
1934
+ )
1935
+ current_prompt = req.prompt[0]
1936
+ elif not req.prompt: # empty list
1937
+ current_prompt = ""
1938
+ else: # list with one element
1939
+ current_prompt = req.prompt[0]
1940
+ else: # string
1941
+ current_prompt = req.prompt
1942
+
1943
+ pooled_state = PooledRequestState(
1944
+ original_request=req,
1945
+ completion_event=completion_event,
1946
+ result_container=result_container,
1947
+ request_id=request_id,
1948
+ request_type="completion",
1949
+ pool_signature=pool_sig,
1950
+ effective_max_tokens=effective_max_tokens,
1951
+ original_prompt=current_prompt, # Store single prompt for chunking
1952
+ )
1953
+
1954
+ try:
1955
+ await request_queue.put(pooled_state)
1956
+ except Exception as e:
1957
+ logger.error(
1958
+ f"Enqueueing error for completion {request_id}: {e}", exc_info=True
1959
+ )
1960
+ return JSONResponse(
1961
+ status_code=500,
1962
+ content={"error": "Internal server error while queueing request."},
1963
+ )
1964
+
1965
+ try:
1966
+ await asyncio.wait_for(
1967
+ completion_event.wait(),
1968
+ timeout=script_args.batch_request_timeout_seconds,
1969
+ )
1970
+ except asyncio.TimeoutError:
1971
+ logger.error(
1972
+ f"Timeout for completion request {request_id} (model {req.model})."
1973
+ )
1974
+ pooled_state.timed_out = True
1975
+ pooled_state.completed_and_signaled = True
1976
+ pooled_state.completion_event.set()
1977
+ return JSONResponse(
1978
+ status_code=504, content={"error": "Request timed out."}
1979
+ )
1980
+ except Exception as e:
1981
+ logger.error(
1982
+ f"Error waiting for completion event for {request_id}: {e}",
1983
+ exc_info=True,
1984
+ )
1985
+ return JSONResponse(
1986
+ status_code=500,
1987
+ content={
1988
+ "error": "Internal server error while waiting for completion."
1989
+ },
1990
+ )
1991
+
1992
+ response_data = result_container[0]
1993
+ if response_data is None:
1994
+ logger.error(
1995
+ f"No result for completion {request_id} (model {req.model}) after event set. Internal error."
1996
+ )
1997
+ return JSONResponse(
1998
+ status_code=500,
1999
+ content={"error": "Internal error: No result from processor."},
2000
+ )
2001
+
2002
+ if isinstance(response_data, JSONResponse):
2003
+ return response_data
2004
+
2005
+ if isinstance(response_data, OACompletionResponse):
2006
+ return JSONResponse(
2007
+ response_data.model_dump()
2008
+ ) # Must return JSONResponse for FastAPI
2009
+ else:
2010
+ logger.error(
2011
+ f"Unexpected result type {type(response_data)} for completion {request_id}."
2012
+ )
2013
+ return JSONResponse(
2014
+ status_code=500,
2015
+ content={"error": "Internal error: Unexpected result format."},
2016
+ )
2017
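+ # Example /v1/completions request (illustrative; values are placeholders):
+ #
+ #   curl http://localhost:8000/v1/completions \
+ #     -H "Content-Type: application/json" \
+ #     -d '{"model": "my-model", "prompt": "Once upon a time", "max_tokens": 32, "n": 1}'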
+
2018
+ class InitCommunicatorRequest(BaseModel):
2019
+ host: str
2020
+ port: int
2021
+ world_size: int
2022
+
2023
+ @app.post("/init_communicator/")
2024
+ async def init_communicator(request: InitCommunicatorRequest):
2025
+ """
2026
+ Initializes the communicator for synchronizing model weights between a client and multiple server
2027
+ workers.
2028
+
2029
+ Args:
2030
+ request (`InitCommunicatorRequest`):
2031
+ - `host` (`str`): Hostname or IP address of the master node.
2032
+ - `port` (`int`): Port number to be used for communication.
2033
+ - `world_size` (`int`): Total number of participating processes in the group.
2034
+ """
2035
+ logger.info(
2036
+ f"[INIT_COMMUNICATOR] Received request: host={request.host}, port={request.port}, world_size={request.world_size}"
2037
+ )
2038
+
2039
+ # Calculate actual world size based on vLLM configuration
2040
+ vllm_world_size = (
2041
+ script_args.tensor_parallel_size * script_args.data_parallel_size
2042
+ )
2043
+ expected_world_size = vllm_world_size + 1 # +1 for the client
2044
+
2045
+ logger.info(
2046
+ f"[INIT_COMMUNICATOR] vLLM world size: {vllm_world_size} (TP={script_args.tensor_parallel_size} x DP={script_args.data_parallel_size})"
2047
+ )
2048
+ logger.info(
2049
+ f"[INIT_COMMUNICATOR] Expected total world size: {expected_world_size}"
2050
+ )
2051
+
2052
+ if request.world_size != expected_world_size:
2053
+ logger.warning(
2054
+ f"[INIT_COMMUNICATOR] World size mismatch! Request: {request.world_size}, Expected: {expected_world_size}"
2055
+ )
2056
+
2057
+ # The function init_communicator is called this way: init_communicator(host, port, world_size)
2058
+ # So with collective_rpc we need to call it this way:
2059
+ # llm.collective_rpc(method="init_communicator", args=(host, port, world_size))
2060
+ kwargs = {
2061
+ "method": "init_communicator",
2062
+ "args": (request.host, request.port, expected_world_size),
2063
+ }
2064
+
2065
+ # Send to all workers synchronously to ensure they're ready
2066
+ successful_workers = []
2067
+ failed_workers = []
2068
+
2069
+ for i, connection in enumerate(connections):
2070
+ logger.debug(f"[INIT_COMMUNICATOR] Sending to worker {i}")
2071
+ try:
2072
+ connection.send(
2073
+ {
2074
+ "type": "fire_and_forget",
2075
+ "method": "collective_rpc",
2076
+ "kwargs": kwargs,
2077
+ }
2078
+ )
2079
+ successful_workers.append(i)
2080
+ except Exception as e:
2081
+ logger.error(f"[INIT_COMMUNICATOR] Failed to notify worker {i}: {e}")
2082
+ failed_workers.append((i, str(e)))
2083
+
2084
+ if failed_workers:
2085
+ error_msg = f"Failed to notify workers: {failed_workers}"
2086
+ logger.error(f"[INIT_COMMUNICATOR] {error_msg}")
2087
+ return JSONResponse(status_code=500, content={"error": error_msg})
2088
+
2089
+ logger.info(
2090
+ f"[INIT_COMMUNICATOR] Successfully notified {len(successful_workers)} workers"
2091
+ )
2092
+ return {
2093
+ "message": "Request received, initializing communicator",
2094
+ "workers_notified": len(successful_workers),
2095
+ }
2096
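+ # Example client-side call (illustrative; host and port are placeholders and must match
+ # the process group the trainer creates for NCCL weight broadcasts; world_size should be
+ # tensor_parallel_size * data_parallel_size + 1, which the server recomputes and only
+ # warns about on mismatch):
+ #
+ #   import requests
+ #   requests.post("http://localhost:8000/init_communicator/",
+ #                 json={"host": "10.0.0.1", "port": 51216, "world_size": 3})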
+
2097
+ class UpdateWeightsRequest(BaseModel):
2098
+ name: str
2099
+ dtype: str
2100
+ shape: list[int]
2101
+
2102
+ class BatchUpdateWeightsRequest(BaseModel):
2103
+ updates: list[UpdateWeightsRequest]
2104
+
2105
+ @app.post("/update_named_param/")
2106
+ async def update_named_param(request: UpdateWeightsRequest):
2107
+ """
2108
+ Updates the model weights with the provided tensor.
2109
+
2110
+ Once this endpoint is called, the client process should broadcast the updated weights to all server workers.
2111
+
2112
+ Args:
2113
+ request (`UpdateWeightsRequest`):
2114
+ - `name` (`str`): Name of the weight tensor being updated.
2115
+ - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
2116
+ - `shape` (list of `int`): Shape of the weight tensor.
2117
+
2118
+ """
2119
+ # Acquire semaphore to limit concurrent weight updates
2120
+ async with weight_update_semaphore:
2121
+ logger.info(
2122
+ f"[UPDATE_PARAM] Received weight update for: {request.name}, dtype={request.dtype}, shape={request.shape}"
2123
+ )
2124
+
2125
+ # Wait for active generations to complete before updating weights
2126
+ # This prevents conflicts between generation and weight loading
2127
+ wait_start = time.time()
2128
+ if active_generation_count > 0:
2129
+ logger.info(
2130
+ f"[UPDATE_PARAM] Waiting for {active_generation_count} active generations to complete before weight update"
2131
+ )
2132
+ while active_generation_count > 0:
2133
+ if time.time() - wait_start > 30.0: # 30 second timeout
2134
+ logger.warning(
2135
+ f"[UPDATE_PARAM] Timeout waiting for {active_generation_count} active generations to complete"
2136
+ )
2137
+ break
2138
+ await asyncio.sleep(0.1)
2139
+ if active_generation_count == 0:
2140
+ logger.debug(
2141
+ f"[UPDATE_PARAM] All generations complete, proceeding with weight update"
2142
+ )
2143
+
2144
+ # CRITICAL: Notify workers IMMEDIATELY so they're ready for NCCL broadcast
2145
+ # This must happen before returning the HTTP response to maintain synchronization with trainer
2146
+ # dtype = torch.__getattribute__(request.dtype.split(".")[-1])
2147
+ kwargs = {
2148
+ "method": "update_named_param",
2149
+ "args": (request.name, request.dtype, tuple(request.shape)),
2150
+ }
2151
+
2152
+ # Send to all workers synchronously to ensure they're ready
2153
+ # Using fire_and_forget since we don't need the result
2154
+ for i, connection in enumerate(connections):
2155
+ logger.debug(f"[UPDATE_PARAM] Notifying worker {i} about weight update")
2156
+ try:
2157
+ connection.send(
2158
+ {
2159
+ "type": "fire_and_forget",
2160
+ "method": "collective_rpc",
2161
+ "kwargs": kwargs,
2162
+ }
2163
+ )
2164
+ except Exception as e:
2165
+ logger.error(f"[UPDATE_PARAM] Failed to notify worker {i}: {e}")
2166
+ return JSONResponse(
2167
+ status_code=500,
2168
+ content={"error": f"Failed to notify worker {i}: {str(e)}"},
2169
+ )
2170
+
2171
+ logger.debug(
2172
+ f"[UPDATE_PARAM] All workers notified, trainer should now broadcast via NCCL"
2173
+ )
2174
+ return {"message": "Weight update processed"}
2175
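+ # Sketch of the trainer-side flow this endpoint expects (illustrative; `communicator`
+ # stands for whatever NCCL group the trainer established via /init_communicator/,
+ # e.g. a PyNcclCommunicator, and `client_rank` is its rank in that group):
+ #
+ #   requests.post(f"{base_url}/update_named_param/",
+ #                 json={"name": name, "dtype": str(param.dtype), "shape": list(param.shape)})
+ #   communicator.broadcast(param.data, src=client_rank)  # workers receive the new weights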
+
2176
+ @app.post("/batch_update_named_params/")
2177
+ async def batch_update_named_params(request: BatchUpdateWeightsRequest):
2178
+ """
2179
+ Updates multiple model weights in a batch. Processes updates sequentially
2180
+ to ensure proper synchronization with NCCL broadcasts from the client.
2181
+
2182
+ Args:
2183
+ request (`BatchUpdateWeightsRequest`):
2184
+ - `updates` (list of `UpdateWeightsRequest`): List of weight updates to process
2185
+ """
2186
+ logger.info(
2187
+ f"[BATCH_UPDATE] Received batch of {len(request.updates)} weight updates"
2188
+ )
2189
+
2190
+ # Process updates sequentially to maintain NCCL synchronization
2191
+ # The client will broadcast each parameter after we notify workers
2192
+ successful = []
2193
+ errors = []
2194
+
2195
+ for update in request.updates:
2196
+ try:
2197
+ # Acquire semaphore to limit concurrent updates across different batches
2198
+ async with weight_update_semaphore:
2199
+ logger.debug(
2200
+ f"[BATCH_UPDATE] Processing weight update for: {update.name}"
2201
+ )
2202
+
2203
+ # Wait for active generations if needed
2204
+ wait_start = time.time()
2205
+ while active_generation_count > 0:
2206
+ if time.time() - wait_start > 30.0:
2207
+ logger.warning(
2208
+ f"[BATCH_UPDATE] Timeout waiting for generations"
2209
+ )
2210
+ break
2211
+ await asyncio.sleep(0.1)
2212
+
2213
+ # Notify workers synchronously
2214
+ dtype = getattr(torch, update.dtype.split(".")[-1])
2215
+ kwargs = {
2216
+ "method": "update_named_param",
2217
+ "args": (update.name, dtype, tuple(update.shape)),
2218
+ }
2219
+
2220
+ for i, connection in enumerate(connections):
2221
+ try:
2222
+ connection.send(
2223
+ {
2224
+ "type": "fire_and_forget",
2225
+ "method": "collective_rpc",
2226
+ "kwargs": kwargs,
2227
+ }
2228
+ )
2229
+ except Exception as e:
2230
+ logger.error(
2231
+ f"[BATCH_UPDATE] Failed to notify worker {i} for {update.name}: {e}"
2232
+ )
2233
+ raise Exception(f"Failed to notify worker {i}")
2234
+
2235
+ successful.append(update.name)
2236
+ logger.debug(f"[BATCH_UPDATE] Workers notified for {update.name}")
2237
+
2238
+ except Exception as e:
2239
+ errors.append({"param": update.name, "error": str(e)})
2240
+ logger.error(f"[BATCH_UPDATE] Error processing {update.name}: {e}")
2241
+
2242
+ if errors:
2243
+ return JSONResponse(
2244
+ status_code=207, # Multi-Status
2245
+ content={
2246
+ "message": f"Batch update completed with {len(errors)} errors",
2247
+ "successful": successful,
2248
+ "errors": errors,
2249
+ },
2250
+ )
2251
+
2252
+ logger.info(
2253
+ f"[BATCH_UPDATE] Successfully processed {len(successful)} weight updates"
2254
+ )
2255
+ return {
2256
+ "message": f"Successfully updated {len(successful)} parameters",
2257
+ "successful": successful,
2258
+ }
2259
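+ # Example payload (illustrative parameter names and shapes):
+ #
+ #   {"updates": [
+ #       {"name": "model.layers.0.mlp.up_proj.weight", "dtype": "torch.bfloat16", "shape": [4096, 1024]},
+ #       {"name": "model.layers.0.mlp.down_proj.weight", "dtype": "torch.bfloat16", "shape": [1024, 4096]}
+ #   ]}
+ #
+ # The client then broadcasts each tensor over NCCL in the same order the updates are listed.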
+
2260
+ @app.post("/reset_prefix_cache/")
2261
+ async def reset_prefix_cache():
2262
+ """
2263
+ Resets the prefix cache for the model.
2264
+ """
2265
+ # Send requests and collect results synchronously
2266
+ all_outputs = []
2267
+ for connection in connections:
2268
+ try:
2269
+ connection.send({"type": "call", "method": "reset_prefix_cache"})
2270
+ output = connection.recv()
2271
+ all_outputs.append(output)
2272
+ except Exception as e:
2273
+ logger.error(f"Failed to reset prefix cache on worker: {e}")
2274
+ all_outputs.append(False)
2275
+
2276
+ success = all(all_outputs)
2277
+ return {
2278
+ "message": "Request received, resetting prefix cache status: "
2279
+ + str(success)
2280
+ }
2281
+
2282
+ @app.post("/close_communicator/")
2283
+ async def close_communicator():
2284
+ """
2285
+ Closes the weight update group and cleans up associated resources.
2286
+ """
2287
+ kwargs = {"method": "close_communicator"}
2288
+
2289
+ # Send to all workers
2290
+ for connection in connections:
2291
+ try:
2292
+ connection.send(
2293
+ {
2294
+ "type": "fire_and_forget",
2295
+ "method": "collective_rpc",
2296
+ "kwargs": kwargs,
2297
+ }
2298
+ )
2299
+ except Exception as e:
2300
+ logger.warning(f"Failed to send close_communicator to worker: {e}")
2301
+ # Don't fail the request if we can't notify a worker during shutdown
2302
+
2303
+ return {"message": "Request received, closing communicator"}
2304
+
2305
+ # Start the server
2306
+ # Always use 1 Uvicorn worker. vLLM handles its own worker processes and scheduling.
2307
+ num_uvicorn_workers = 1
2308
+
2309
+ logger.info(
2310
+ f"Starting Uvicorn with {num_uvicorn_workers} worker(s). Data parallel size: {script_args.data_parallel_size}"
2311
+ )
2312
+ uvicorn.run(
2313
+ app,
2314
+ host=script_args.host,
2315
+ port=script_args.port,
2316
+ log_level=script_args.log_level,
2317
+ workers=num_uvicorn_workers,
2318
+ )
2319
+
2320
+
2321
+ def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
2322
+ if subparsers is not None:
2323
+ parser = subparsers.add_parser(
2324
+ "vllm-serve",
2325
+ help="Run the vLLM serve script",
2326
+ dataclass_types=ScriptArguments,
2327
+ )
2328
+ else:
2329
+ parser = TrlParser(ScriptArguments)
2330
+ return parser
2331
+
2332
+
2333
+ if __name__ == "__main__":
2334
+ parser = make_parser()
2335
+ (script_args,) = parser.parse_args_and_config()
2336
+ main(script_args)