vllm_router-0.1.10-cp39-cp39-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vllm_router/__init__.py +9 -0
- vllm_router/launch_router.py +109 -0
- vllm_router/mini_lb.py +395 -0
- vllm_router/router.py +148 -0
- vllm_router/router_args.py +592 -0
- vllm_router/version.py +1 -0
- vllm_router-0.1.10.dist-info/METADATA +284 -0
- vllm_router-0.1.10.dist-info/RECORD +14 -0
- vllm_router-0.1.10.dist-info/WHEEL +5 -0
- vllm_router-0.1.10.dist-info/entry_points.txt +2 -0
- vllm_router-0.1.10.dist-info/top_level.txt +1 -0
- vllm_router.libs/libcrypto-3eda328c.so.1.1.1k +0 -0
- vllm_router.libs/libssl-f60bf0e2.so.1.1.1k +0 -0
- vllm_router_rs.cpython-39-aarch64-linux-gnu.so +0 -0
vllm_router/launch_router.py
ADDED
@@ -0,0 +1,109 @@
import argparse
import logging
import sys
from typing import List, Optional

import setproctitle
from vllm_router.mini_lb import MiniLoadBalancer
from vllm_router.router_args import RouterArgs

logger = logging.getLogger("router")

try:
    from vllm_router.router import Router
except ImportError:
    Router = None
    logger.warning(
        "Rust Router is not installed, only python MiniLB (debugging only) is available"
    )


def launch_router(args: argparse.Namespace) -> Optional[Router]:
    """
    Launch the VLLM router with the configuration from parsed arguments.

    Args:
        args: Namespace object containing the router configuration.
            Can be either a raw argparse.Namespace or a converted RouterArgs.

    Returns:
        Router instance if successful, None if failed
    """
    setproctitle.setproctitle("vllm::router")
    try:
        # Convert to RouterArgs if needed
        if not isinstance(args, RouterArgs):
            router_args = RouterArgs.from_cli_args(args)
        else:
            router_args = args

        if router_args.mini_lb:
            mini_lb = MiniLoadBalancer(router_args)
            mini_lb.start()
        else:
            if Router is None:
                raise RuntimeError("Rust Router is not installed")
            router_args._validate_router_args()
            router = Router.from_args(router_args)
            router.start()

    except Exception as e:
        logger.error(f"Error starting router: {e}")
        raise e


class CustomHelpFormatter(
    argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
    """Custom formatter that preserves both description formatting and shows defaults"""

    pass


def parse_router_args(args: List[str]) -> RouterArgs:
    """Parse command line arguments and return a RouterArgs instance."""
    parser = argparse.ArgumentParser(
        description="""VLLM Router - High-performance request distribution across worker nodes

Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and the router separately.

Examples:
  # Regular mode
  vllm-router --worker-urls http://worker1:8000 http://worker2:8000

  # PD disaggregated mode with the same policy for both
  vllm-router --pd-disaggregation \\
    --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 \\
    --decode http://decode1:8001 --decode http://decode2:8001 \\
    --policy cache_aware

  # PD mode with optional bootstrap ports
  vllm-router --pd-disaggregation \\
    --prefill http://prefill1:8000 9000 \\  # With bootstrap port
    --prefill http://prefill2:8000 none \\  # Explicitly no bootstrap port
    --prefill http://prefill3:8000 \\       # Defaults to no bootstrap port
    --decode http://decode1:8001 --decode http://decode2:8001

  # PD mode with different policies for prefill and decode
  vllm-router --pd-disaggregation \\
    --prefill http://prefill1:8000 --prefill http://prefill2:8000 \\
    --decode http://decode1:8001 --decode http://decode2:8001 \\
    --prefill-policy cache_aware --decode-policy power_of_two
""",
        formatter_class=CustomHelpFormatter,
    )

    RouterArgs.add_cli_args(parser, use_router_prefix=False)
    return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)


def main() -> None:
    router_args = parse_router_args(sys.argv[1:])
    launch_router(router_args)


if __name__ == "__main__":
    main()
vllm_router/mini_lb.py
ADDED
@@ -0,0 +1,395 @@
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""

import asyncio
import ipaddress
import logging
import random
import urllib
from http import HTTPStatus
from itertools import chain
from typing import Optional

import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from vllm_router.router_args import RouterArgs

logger = logging.getLogger(__name__)

AIOHTTP_STREAM_READ_CHUNK_SIZE = (
    1024 * 64
)  # 64 KB, to prevent aiohttp's "Chunk too big" error


def maybe_wrap_ipv6_address(address: str) -> str:
    try:
        ipaddress.IPv6Address(address)
        return f"[{address}]"
    except ValueError:
        return address


class MiniLoadBalancer:
    def __init__(
        self,
        router_args: RouterArgs,
    ):
        self._validate_router_args(router_args)

        self.host = router_args.host
        self.port = router_args.port
        self.timeout = router_args.request_timeout_secs
        self.prefill_urls = [url[0] for url in router_args.prefill_urls]
        self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
        self.decode_urls = router_args.decode_urls

    def _validate_router_args(self, router_args: RouterArgs):
        logger.warning(
            "\x1b[33mMiniLB is only for debugging purposes; it only supports the random policy!\x1b[0m"
        )

        # NOTE: too many arguments are unsupported, so only validate the important ones
        if router_args.policy != "random":
            logger.warning("[MiniLB] Overriding policy to random")
            router_args.policy = "random"

        if not router_args.pd_disaggregation:
            raise ValueError("MiniLB only supports PD disaggregation mode")

        if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
            raise ValueError(
                "MiniLB requires at least one prefill and one decode server"
            )

    def start(self):
        global lb
        lb = self
        uvicorn.run(app, host=self.host, port=self.port)

    def select_pair(self):
        assert len(self.prefill_urls) > 0, "No prefill servers available"
        assert len(self.decode_urls) > 0, "No decode servers available"
        pidx = random.randint(0, len(self.prefill_urls) - 1)
        didx = random.randint(0, len(self.decode_urls) - 1)
        return (
            self.prefill_urls[pidx],
            self.prefill_bootstrap_ports[pidx],
            self.decode_urls[didx],
        )

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.timeout
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should finish first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            if "return_logprob" in modified_request:
                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # Merge `meta_info.input_token_logprobs` from prefill into decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=self.timeout
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient.
                        # Merge prefill input_token_logprobs, output_token_logprobs into decode
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            yield chunk
                else:
                    async for chunk in decode_response.content.iter_chunked(
                        AIOHTTP_STREAM_READ_CHUNK_SIZE
                    ):
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )


app = FastAPI()
lb: Optional[MiniLoadBalancer] = None


@app.get("/health")
async def health_check():
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate():
    async with aiohttp.ClientSession() as session:
        # Create the tasks
        tasks = []
        for server in chain(lb.prefill_urls, lb.decode_urls):
            tasks.append(session.get(f"{server}/health_generate"))
        for i, response in enumerate(asyncio.as_completed(tasks)):
            await response
    return Response(status_code=200)


@app.post("/flush_cache")
async def flush_cache():
    async with aiohttp.ClientSession() as session:
        # Create the tasks
        tasks = []
        for server in chain(lb.prefill_urls, lb.decode_urls):
            tasks.append(session.post(f"{server}/flush_cache"))
        for i, response in enumerate(asyncio.as_completed(tasks)):
            await response
    return Response(status_code=200)


@app.get("/get_server_info")
async def get_server_info():
    prefill_infos = []
    decode_infos = []
    all_internal_states = []

    async with aiohttp.ClientSession() as session:
        for server in lb.prefill_urls:
            server_info = await session.get(f"{server}/get_server_info")
            prefill_infos.append(await server_info.json())
        for server in lb.decode_urls:
            server_info = await session.get(f"{server}/get_server_info")
            info_json = await server_info.json()
            decode_infos.append(info_json)
            # Extract internal_states from decode servers
            if "internal_states" in info_json:
                all_internal_states.extend(info_json["internal_states"])

    # Return the format expected by bench_one_batch_server.py
    if all_internal_states:
        return {
            "internal_states": all_internal_states,
            "prefill": prefill_infos,
            "decode": decode_infos,
        }
    else:
        # Fall back to dummy data if no internal states were found
        return {
            "internal_states": [
                {
                    "last_gen_throughput": 0.0,
                    "avg_spec_accept_length": None,
                }
            ],
            "prefill": prefill_infos,
            "decode": decode_infos,
        }


@app.get("/get_model_info")
async def get_model_info():
    if not lb or not lb.prefill_urls:
        raise HTTPException(
            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            detail="There is no server registered",
        )

    target_server_url = lb.prefill_urls[0]
    endpoint_url = f"{target_server_url}/get_model_info"

    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(endpoint_url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise HTTPException(
                        status_code=HTTPStatus.BAD_GATEWAY,
                        detail=(
                            f"Failed to get model info from {target_server_url}. "
                            f"Status: {response.status}, Response: {error_text}"
                        ),
                    )

                model_info_json = await response.json()
                return ORJSONResponse(content=model_info_json)

        except aiohttp.ClientError:
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="Failed to get model info from backend",
            )


@app.post("/generate")
async def handle_generate_request(request_data: dict):
    prefill_server, bootstrap_port, decode_server = lb.select_pair()

    # Parse and transform prefill_server for bootstrap data
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
    modified_request = request_data.copy()

    batch_size = _get_request_batch_size(modified_request)
    if batch_size is not None:
        modified_request.update(
            {
                "bootstrap_host": [hostname] * batch_size,
                "bootstrap_port": [bootstrap_port] * batch_size,
                "bootstrap_room": [
                    _generate_bootstrap_room() for _ in range(batch_size)
                ],
            }
        )
    else:
        modified_request.update(
            {
                "bootstrap_host": hostname,
                "bootstrap_port": bootstrap_port,
                "bootstrap_room": _generate_bootstrap_room(),
            }
        )

    if request_data.get("stream", False):
        return await lb.generate_stream(
            modified_request, prefill_server, decode_server, "generate"
        )
    else:
        return await lb.generate(
            modified_request, prefill_server, decode_server, "generate"
        )


async def _forward_to_backend(request_data: dict, endpoint_name: str):
    prefill_server, bootstrap_port, decode_server = lb.select_pair()

    # Parse and transform prefill_server for bootstrap data
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
    modified_request = request_data.copy()
    modified_request.update(
        {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    )

    if request_data.get("stream", False):
        return await lb.generate_stream(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )
    else:
        return await lb.generate(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )


@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
    return await _forward_to_backend(request_data, "v1/chat/completions")


@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
    return await _forward_to_backend(request_data, "v1/completions")


def _generate_bootstrap_room():
    return random.randint(0, 2**63 - 1)


# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
    if (text := request.get("text")) is not None:
        return None if isinstance(text, str) else len(text)
    if (input_ids := request.get("input_ids")) is not None:
        return None if isinstance(input_ids[0], int) else len(input_ids)
    return None


@app.get("/v1/models")
async def get_models():
    prefill_server = lb.prefill_urls[0]  # Get the first prefill server
    async with aiohttp.ClientSession() as session:
        try:
            response = await session.get(f"{prefill_server}/v1/models")
            if response.status != 200:
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Prefill server error: Status {response.status}",
                )
            return ORJSONResponse(content=await response.json())
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
vllm_router/router.py
ADDED
@@ -0,0 +1,148 @@
from typing import Optional

from vllm_router.router_args import RouterArgs
from vllm_router_rs import PolicyType
from vllm_router_rs import Router as _Router


def policy_from_str(policy_str: Optional[str]) -> Optional[PolicyType]:
    """Convert a policy string to the PolicyType enum."""
    if policy_str is None:
        return None
    policy_map = {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
    }
    return policy_map[policy_str]


class Router:
    """
    A high-performance router for distributing requests across worker nodes.

    Args:
        worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
            the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
        policy: Load balancing policy to use. Options:
            - PolicyType.Random: Randomly select workers
            - PolicyType.RoundRobin: Distribute requests in round-robin fashion
            - PolicyType.CacheAware: Distribute requests based on cache state and load balance
            - PolicyType.PowerOfTwo: Select the better of two random workers based on load (PD mode only)
        host: Host address to bind the router server. Default: '127.0.0.1'
        port: Port number to bind the router server. Default: 3001
        worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
        worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
        cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to the cached worker
            if the match rate exceeds the threshold, otherwise routes to the worker with the smallest
            tree. Default: 0.5
        balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, cache-aware routing is used. Default: 32
        balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, cache-aware routing is used. Default: 1.0001
        eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
            routing. Default: 60
        max_payload_size: Maximum payload size in bytes. Default: 256MB
        max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
        intra_node_data_parallel_size: Data parallel size for DP-aware routing (automatically enabled when > 1). Default: 1
        enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
            the router can manage multiple models simultaneously with per-model load balancing
            policies. Default: False
        api_key: API key for authorization with workers. Required when using data parallel routing
            (intra_node_data_parallel_size > 1). Default: None
        log_dir: Directory to store log files. If None, logs are only output to the console. Default: None
        log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
        service_discovery: Enable Kubernetes service discovery. When enabled, the router will
            automatically discover worker pods based on the selector. Default: False
        selector: Dictionary mapping label keys to values for Kubernetes pod selection.
            Example: {"app": "vllm-worker"}. Default: {}
        service_discovery_port: Port to use for service discovery. The router will generate
            worker URLs using this port. Default: 80
        service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
            watches pods across all namespaces (requires cluster-wide permissions). Default: None
        prefill_selector: Dictionary mapping label keys to values for Kubernetes pod selection
            for prefill servers (PD mode only). Default: {}
        decode_selector: Dictionary mapping label keys to values for Kubernetes pod selection
            for decode servers (PD mode only). Default: {}
        prometheus_port: Port to expose Prometheus metrics. Default: None
        prometheus_host: Host address to bind the Prometheus metrics server. Default: None
        pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
        vllm_pd_disaggregation: Enable vLLM PD (Prefill-Decode) disaggregated mode. Default: False
        prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
        decode_urls: List of URLs for decode servers (PD mode only)
        prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        decode_policy: Specific load balancing policy for decode nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        request_id_headers: List of HTTP headers to check for request IDs. If not specified,
            uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
            Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
        bootstrap_port_annotation: Kubernetes annotation name for the bootstrap port (PD mode).
            Default: 'vllm.ai/bootstrap-port'
        request_timeout_secs: Request timeout in seconds. Default: 600
        max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
        queue_size: Queue size for pending requests when the max concurrent limit is reached
            (0 = no queue, return 429 immediately). Default: 100
        queue_timeout_secs: Maximum time (in seconds) a request can wait in the queue before timing out. Default: 60
        rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set,
            defaults to max_concurrent_requests. Default: None
        cors_allowed_origins: List of allowed origins for CORS. An empty list allows all origins. Default: []
        health_failure_threshold: Number of consecutive health check failures before marking a worker unhealthy. Default: 3
        health_success_threshold: Number of consecutive health check successes before marking a worker healthy. Default: 2
        health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
        health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
        health_check_endpoint: Health check endpoint path. Default: '/health'
        model_path: Model path for loading the tokenizer (HuggingFace model ID or local path). Default: None
        tokenizer_path: Explicit tokenizer path (overrides the model_path tokenizer if provided). Default: None
    """

    def __init__(self, router: Optional[_Router] = None, **kwargs):
        """Initialize Router either from a _Router instance or from keyword arguments.

        Args:
            router: Optional _Router instance. If provided, kwargs are ignored.
            **kwargs: Keyword arguments to pass to the _Router constructor if router is None.
        """
        if router is not None:
            self._router = router
        else:
            # Create _Router from kwargs
            self._router = _Router(**kwargs)

    @staticmethod
    def from_args(args: RouterArgs) -> "Router":
        """Create a router from a RouterArgs instance."""

        args_dict = vars(args)
        # Convert RouterArgs to _Router parameters
        args_dict["worker_urls"] = (
            []
            if args_dict["service_discovery"]
            or args_dict["pd_disaggregation"]
            or args_dict["vllm_pd_disaggregation"]
            else args_dict["worker_urls"]
        )
        args_dict["policy"] = policy_from_str(args_dict["policy"])
        args_dict["prefill_urls"] = (
            args_dict["prefill_urls"]
            if args_dict["pd_disaggregation"] or args_dict["vllm_pd_disaggregation"]
            else None
        )
        args_dict["decode_urls"] = (
            args_dict["decode_urls"]
            if args_dict["pd_disaggregation"] or args_dict["vllm_pd_disaggregation"]
            else None
        )
        args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
        args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])

        # Remove the mini_lb parameter, which _Router does not accept
        args_dict.pop("mini_lb")

        return Router(router=_Router(**args_dict))

    def start(self) -> None:
        """Start the router server.

        This method blocks until the server is shut down.
        """
        self._router.start()