autoinference-utils 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autoinference_utils-0.1.0/.gitignore +245 -0
- autoinference_utils-0.1.0/PKG-INFO +5 -0
- autoinference_utils-0.1.0/README.md +13 -0
- autoinference_utils-0.1.0/autoinference_utils/__init__.py +0 -0
- autoinference_utils-0.1.0/autoinference_utils/endpoint.py +506 -0
- autoinference_utils-0.1.0/pyproject.toml +9 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[codz]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py.cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
# Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# UV
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
# uv.lock
|
|
102
|
+
|
|
103
|
+
# poetry
|
|
104
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
105
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
106
|
+
# commonly ignored for libraries.
|
|
107
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
108
|
+
# poetry.lock
|
|
109
|
+
# poetry.toml
|
|
110
|
+
|
|
111
|
+
# pdm
|
|
112
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
113
|
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
|
114
|
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
|
115
|
+
# pdm.lock
|
|
116
|
+
# pdm.toml
|
|
117
|
+
.pdm-python
|
|
118
|
+
.pdm-build/
|
|
119
|
+
|
|
120
|
+
# pixi
|
|
121
|
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
|
122
|
+
# pixi.lock
|
|
123
|
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
|
124
|
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
|
125
|
+
.pixi
|
|
126
|
+
|
|
127
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
128
|
+
__pypackages__/
|
|
129
|
+
|
|
130
|
+
# Celery stuff
|
|
131
|
+
celerybeat-schedule
|
|
132
|
+
celerybeat.pid
|
|
133
|
+
|
|
134
|
+
# Redis
|
|
135
|
+
*.rdb
|
|
136
|
+
*.aof
|
|
137
|
+
*.pid
|
|
138
|
+
|
|
139
|
+
# RabbitMQ
|
|
140
|
+
mnesia/
|
|
141
|
+
rabbitmq/
|
|
142
|
+
rabbitmq-data/
|
|
143
|
+
|
|
144
|
+
# ActiveMQ
|
|
145
|
+
activemq-data/
|
|
146
|
+
|
|
147
|
+
# SageMath parsed files
|
|
148
|
+
*.sage.py
|
|
149
|
+
|
|
150
|
+
# Environments
|
|
151
|
+
.env
|
|
152
|
+
.envrc
|
|
153
|
+
.venv
|
|
154
|
+
env/
|
|
155
|
+
venv/
|
|
156
|
+
ENV/
|
|
157
|
+
env.bak/
|
|
158
|
+
venv.bak/
|
|
159
|
+
|
|
160
|
+
# Spyder project settings
|
|
161
|
+
.spyderproject
|
|
162
|
+
.spyproject
|
|
163
|
+
|
|
164
|
+
# Rope project settings
|
|
165
|
+
.ropeproject
|
|
166
|
+
|
|
167
|
+
# mkdocs documentation
|
|
168
|
+
/site
|
|
169
|
+
|
|
170
|
+
# mypy
|
|
171
|
+
.mypy_cache/
|
|
172
|
+
.dmypy.json
|
|
173
|
+
dmypy.json
|
|
174
|
+
|
|
175
|
+
# Pyre type checker
|
|
176
|
+
.pyre/
|
|
177
|
+
|
|
178
|
+
# pytype static type analyzer
|
|
179
|
+
.pytype/
|
|
180
|
+
|
|
181
|
+
# Cython debug symbols
|
|
182
|
+
cython_debug/
|
|
183
|
+
|
|
184
|
+
# PyCharm
|
|
185
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
186
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
187
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
188
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
189
|
+
# .idea/
|
|
190
|
+
|
|
191
|
+
# Abstra
|
|
192
|
+
# Abstra is an AI-powered process automation framework.
|
|
193
|
+
# Ignore directories containing user credentials, local state, and settings.
|
|
194
|
+
# Learn more at https://abstra.io/docs
|
|
195
|
+
.abstra/
|
|
196
|
+
|
|
197
|
+
# Visual Studio Code
|
|
198
|
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
|
199
|
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
|
200
|
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
|
201
|
+
# you could uncomment the following to ignore the entire vscode folder
|
|
202
|
+
# .vscode/
|
|
203
|
+
|
|
204
|
+
# Ruff stuff:
|
|
205
|
+
.ruff_cache/
|
|
206
|
+
|
|
207
|
+
# PyPI configuration file
|
|
208
|
+
.pypirc
|
|
209
|
+
|
|
210
|
+
# Marimo
|
|
211
|
+
marimo/_static/
|
|
212
|
+
marimo/_lsp/
|
|
213
|
+
__marimo__/
|
|
214
|
+
|
|
215
|
+
# Streamlit
|
|
216
|
+
.streamlit/secrets.toml
|
|
217
|
+
# General
|
|
218
|
+
.DS_Store
|
|
219
|
+
__MACOSX/
|
|
220
|
+
.AppleDouble
|
|
221
|
+
.LSOverride
|
|
222
|
+
Icon[
|
|
223
|
+
]
|
|
224
|
+
|
|
225
|
+
# Thumbnails
|
|
226
|
+
._*
|
|
227
|
+
|
|
228
|
+
# Files that might appear in the root of a volume
|
|
229
|
+
.DocumentRevisions-V100
|
|
230
|
+
.fseventsd
|
|
231
|
+
.Spotlight-V100
|
|
232
|
+
.TemporaryItems
|
|
233
|
+
.Trashes
|
|
234
|
+
.VolumeIcon.icns
|
|
235
|
+
.com.apple.timemachine.donotpresent
|
|
236
|
+
|
|
237
|
+
# Directories potentially created on remote AFP share
|
|
238
|
+
.AppleDB
|
|
239
|
+
.AppleDesktop
|
|
240
|
+
Network Trash Folder
|
|
241
|
+
Temporary Items
|
|
242
|
+
.apdisk
|
|
243
|
+
|
|
244
|
+
benchmark_results/
|
|
245
|
+
docs/
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# autoinference-utils
|
|
2
|
+
|
|
3
|
+
Shared endpoint abstractions (`SGLangEndpoint`, `VLLMEndpoint`, `RouterEndpoint`) and health-check/warmup helpers for autoinference deployments.
|
|
4
|
+
|
|
5
|
+
## Publishing
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
cd autoinference_utils
|
|
9
|
+
uv build
|
|
10
|
+
uv publish # requires PyPI token via UV_PUBLISH_TOKEN or --token
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
Bump `version` in `pyproject.toml` before each release.
|
|
File without changes
|
|
@@ -0,0 +1,506 @@
|
|
|
1
|
+
"""Composable endpoint abstractions for inference server subprocess deployments.
|
|
2
|
+
|
|
3
|
+
Adapted from the timmy branch's utils/endpoint.py. Key differences from the
|
|
4
|
+
old SGLangFlash inheritance pattern: these are plain Python objects, not Modal
|
|
5
|
+
classes. Deployments compose them rather than inheriting from them.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import shlex
|
|
12
|
+
import subprocess
|
|
13
|
+
import threading
|
|
14
|
+
import time
|
|
15
|
+
import urllib.error
|
|
16
|
+
import urllib.request
|
|
17
|
+
from abc import ABC
|
|
18
|
+
from typing import Any, Callable, Literal, Mapping, Optional, Sequence
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Endpoint(ABC):
    """Base class for something with a URL that can be started and stopped.

    Subclasses override :meth:`start` and :meth:`stop`; the defaults here are
    no-ops. Instances are context managers: entering calls ``start()`` and
    exiting calls ``stop()``.
    """

    def __init__(self, base_url: str):
        # Strip trailing slashes so callers can append "/path" safely.
        self.base_url = base_url.rstrip("/")

    def __enter__(self):
        # If start() fails part-way (e.g. a subprocess was launched but never
        # became healthy), __exit__ will not run, so clean up here instead of
        # leaking whatever start() managed to bring up.
        try:
            self.start()
        except BaseException:
            self.stop()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def start(self):
        """Bring the endpoint up. The base implementation does nothing."""
        return

    def stop(self):
        """Tear the endpoint down. The base implementation does nothing."""
        return
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SGLangEndpoint(Endpoint):
    """Manages an SGLang server subprocess on the local machine.

    The server is launched as ``python -m <launcher_module>`` listening on
    ``0.0.0.0:<worker_port>``; :meth:`start` blocks until
    ``http://127.0.0.1:<worker_port>/health`` answers 2xx.
    """

    def __init__(
        self,
        *,
        model_path: str,
        worker_port: int = 8000,
        tp: Optional[int] = None,
        ep: Optional[int] = None,
        dp: Optional[int] = None,
        speculative_model_path: Optional[str] = None,
        load_format: Optional[str] = None,
        nnodes: int = 1,
        node_rank: int = 0,
        dist_init_host: Optional[str] = None,
        dist_init_port: int = 1234,
        disaggregation_mode: Optional[Literal["prefill", "decode"]] = None,
        prefill_bootstrap_port: int = 8998,
        launcher_module: str = "sglang.launch_server",
        extra_server_args: Optional[dict[str, str]] = None,
        health_timeout: float = 20 * 60,
        health_poll_interval: float = 5.0,
        health_request_timeout: float = 5.0,
    ):
        """Record launch settings; nothing is started until :meth:`start`.

        Args:
            model_path: Passed as ``--model-path``.
            worker_port: Port the server listens on (``--port``) and that
                health checks target.
            tp: Passed as ``--tp`` when set.
            ep: Passed as ``--ep`` when set.
            dp: Passed as ``--dp`` when set, together with
                ``--enable-dp-attention``.
            speculative_model_path: Passed as
                ``--speculative-draft-model-path`` when set.
            load_format: Passed as ``--load-format`` when set.
            nnodes: Number of nodes; multi-node flags are emitted only when
                greater than 1.
            node_rank: Passed as ``--node-rank`` when ``nnodes > 1``.
            dist_init_host: Host part of ``--dist-init-addr``; required when
                ``nnodes > 1``.
            dist_init_port: Port part of ``--dist-init-addr``.
            disaggregation_mode: ``"prefill"``, ``"decode"`` or ``None``.
            prefill_bootstrap_port: Passed as
                ``--disaggregation-bootstrap-port`` in prefill mode.
            launcher_module: Module executed with ``python -m``.
            extra_server_args: ``flag -> value`` pairs merged over
                ``DEFAULT_OPERATIONAL_ARGS`` (these win on key clashes). An
                empty value emits a bare flag; other values are converted to
                ``str`` and split on whitespace.
            health_timeout: Seconds :meth:`start` waits for readiness.
            health_poll_interval: Seconds between readiness probes.
            health_request_timeout: Per-probe HTTP timeout in seconds.

        Raises:
            ValueError: ``disaggregation_mode`` is not None, "prefill" or
                "decode".
        """
        super().__init__(base_url=f"http://localhost:{worker_port}")
        self.worker_port = worker_port
        self.model_path = model_path
        self.tp = tp
        self.ep = ep
        self.dp = dp
        self.speculative_model_path = speculative_model_path
        self.load_format = load_format
        self.nnodes = nnodes
        self.node_rank = node_rank
        self.dist_init_host = dist_init_host
        self.dist_init_port = dist_init_port
        self.disaggregation_mode = disaggregation_mode
        self.prefill_bootstrap_port = prefill_bootstrap_port
        self.launcher_module = launcher_module
        # Copy so later mutation of the caller's dict does not leak in.
        self.extra_server_args = dict(extra_server_args) if extra_server_args else {}
        self.health_timeout = health_timeout
        self.health_poll_interval = health_poll_interval
        self.health_request_timeout = health_request_timeout
        self._proc: Optional[subprocess.Popen] = None

        if self.disaggregation_mode not in (None, "prefill", "decode"):
            raise ValueError("disaggregation_mode must be None, 'prefill', or 'decode'")

    # Flags added to every launch. An empty-string value means "bare flag".
    # Shared class attribute: treat as read-only.
    DEFAULT_OPERATIONAL_ARGS: dict[str, str] = {
        "--enable-metrics": "",
        "--decode-log-interval": "1",
        "--enable-cache-report": "",
        "--model-loader-extra-config": '{"enable_multithread_load":true,"num_threads":64}',
    }

    def _build_cmd(self) -> list[str]:
        """Return the argv list used to launch the server.

        Raises:
            ValueError: ``nnodes > 1`` but ``dist_init_host`` is not set.
        """
        cmd = [
            "python", "-m", self.launcher_module,
            "--host", "0.0.0.0",
            "--port", str(self.worker_port),
            "--model-path", self.model_path,
        ]

        if self.speculative_model_path is not None:
            cmd.extend(["--speculative-draft-model-path", self.speculative_model_path])

        if self.load_format is not None:
            cmd.extend(["--load-format", self.load_format])

        if self.tp is not None:
            cmd.extend(["--tp", str(self.tp)])
        if self.ep is not None:
            cmd.extend(["--ep", str(self.ep)])
        if self.dp is not None:
            cmd.extend(["--dp", str(self.dp), "--enable-dp-attention"])

        if self.disaggregation_mode is not None:
            cmd.extend(["--disaggregation-mode", self.disaggregation_mode])
            if self.disaggregation_mode == "prefill":
                cmd.extend(["--disaggregation-bootstrap-port", str(self.prefill_bootstrap_port)])

        if self.nnodes > 1:
            if self.dist_init_host is None:
                raise ValueError("dist_init_host is required when nnodes > 1")
            cmd.extend([
                "--nnodes", str(self.nnodes),
                "--node-rank", str(self.node_rank),
                "--dist-init-addr", f"{self.dist_init_host}:{self.dist_init_port}",
            ])

        merged = {**self.DEFAULT_OPERATIONAL_ARGS, **self.extra_server_args}
        for key, value in merged.items():
            # Accept non-str values (e.g. ints) instead of failing on .split().
            text = str(value)
            if text == "":
                cmd.append(key)
            else:
                cmd.extend([key, *text.split()])

        return cmd

    def health_check(self) -> str | None:
        """Probe ``/health`` once.

        Returns:
            ``None`` if the server answered 2xx, otherwise an error string.

        Raises:
            subprocess.CalledProcessError: the managed server process (if
                one has been started) has exited.
        """
        url = f"http://127.0.0.1:{self.worker_port}/health"
        return _health_check(url, request_timeout=self.health_request_timeout, process=self._proc)

    def start(self):
        """Launch the server and block until it reports healthy.

        On failure the subprocess is terminated before the exception is
        re-raised, so a half-started server is not left running.
        NOTE(review): calling start() twice without stop() replaces
        ``self._proc`` and orphans the first server.

        Raises:
            TimeoutError: not healthy within ``health_timeout`` seconds.
            subprocess.CalledProcessError: the server exited while waiting.
        """
        cmd = self._build_cmd()
        print(f"[endpoint] starting: {shlex.join(cmd)}")
        self._proc = subprocess.Popen(cmd)
        try:
            wait_ready(
                self._proc,
                port=self.worker_port,
                timeout=self.health_timeout,
                poll_interval=self.health_poll_interval,
                request_timeout=self.health_request_timeout,
            )
        except BaseException:
            self.stop()
            raise

    def stop(self):
        """Terminate the server subprocess (no-op if not running)."""
        terminate_process(self._proc)
        self._proc = None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class VLLMEndpoint(Endpoint):
    """Manages a vLLM server subprocess on the local machine.

    The server is launched as ``python -m vllm.entrypoints.openai.api_server``
    listening on ``0.0.0.0:<worker_port>``; :meth:`start` blocks until
    ``http://127.0.0.1:<worker_port>/health`` answers 2xx.
    """

    def __init__(
        self,
        *,
        model: str,
        worker_port: int = 8000,
        extra_server_args: Optional[dict[str, str]] = None,
        health_timeout: float = 20 * 60,
        health_poll_interval: float = 5.0,
        health_request_timeout: float = 5.0,
    ):
        """Record launch settings; nothing is started until :meth:`start`.

        Args:
            model: Passed as ``--model``.
            worker_port: Port the server listens on and health checks target.
            extra_server_args: Extra ``flag -> value`` pairs. An empty value
                emits a bare flag; other values are converted to ``str`` and
                split on whitespace.
            health_timeout: Seconds :meth:`start` waits for readiness.
            health_poll_interval: Seconds between readiness probes.
            health_request_timeout: Per-probe HTTP timeout in seconds.
        """
        super().__init__(base_url=f"http://localhost:{worker_port}")
        self.model = model
        self.worker_port = worker_port
        # Copy so later mutation of the caller's dict does not leak in.
        self.extra_server_args = dict(extra_server_args) if extra_server_args else {}
        self.health_timeout = health_timeout
        self.health_poll_interval = health_poll_interval
        self.health_request_timeout = health_request_timeout
        self._proc: Optional[subprocess.Popen] = None

    def _build_cmd(self) -> list[str]:
        """Return the argv list used to launch the vLLM server."""
        cmd = [
            "python", "-m", "vllm.entrypoints.openai.api_server",
            "--host", "0.0.0.0",
            "--port", str(self.worker_port),
            "--model", self.model,
        ]
        for key, value in self.extra_server_args.items():
            # Accept non-str values (e.g. ints) instead of failing on .split().
            text = str(value)
            if text == "":
                cmd.append(key)
            else:
                cmd.extend([key, *text.split()])
        return cmd

    def health_check(self) -> str | None:
        """Probe ``/health`` once.

        Returns:
            ``None`` if the server answered 2xx, otherwise an error string.

        Raises:
            subprocess.CalledProcessError: the managed server process (if
                one has been started) has exited.
        """
        url = f"http://127.0.0.1:{self.worker_port}/health"
        return _health_check(url, request_timeout=self.health_request_timeout, process=self._proc)

    def start(self):
        """Launch the server and block until it reports healthy.

        On failure the subprocess is terminated before the exception is
        re-raised, so a half-started server is not left running.

        Raises:
            TimeoutError: not healthy within ``health_timeout`` seconds.
            subprocess.CalledProcessError: the server exited while waiting.
        """
        cmd = self._build_cmd()
        print(f"[vllm] starting: {shlex.join(cmd)}")
        self._proc = subprocess.Popen(cmd)
        try:
            wait_ready(
                self._proc,
                port=self.worker_port,
                timeout=self.health_timeout,
                poll_interval=self.health_poll_interval,
                request_timeout=self.health_request_timeout,
            )
        except BaseException:
            self.stop()
            raise

    def stop(self):
        """Terminate the server subprocess (no-op if not running)."""
        terminate_process(self._proc)
        self._proc = None
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class RouterEndpoint(Endpoint):
    """Manages an SGLang router process for PD disaggregation.

    :meth:`start` first waits for every node in ``pd_config`` to report
    healthy, then launches ``python -m sglang_router.launch_router`` on
    ``0.0.0.0:<router_port>`` and waits for the router itself.
    """

    def __init__(
        self,
        *,
        pd_config: Sequence[tuple[str, str]],
        worker_port: int = 8000,
        router_port: int = 9000,
        prefill_bootstrap_port: int = 8998,
        api_key: Optional[str] = None,
        health_timeout: float = 10 * 60,
        health_poll_interval: float = 5.0,
    ):
        """Record router settings; nothing is started until :meth:`start`.

        Args:
            pd_config: ``(role, node_ip)`` pairs; role is ``"prefill"``,
                ``"decode"`` or ``"worker"``.
            worker_port: Port each node's server listens on.
            router_port: Port the router listens on.
            prefill_bootstrap_port: Bootstrap port advertised for prefill nodes.
            api_key: Passed as ``--api-key`` when set.
            health_timeout: Seconds allowed for *each* readiness wait (one per
                node, plus one for the router).
            health_poll_interval: Seconds between readiness probes.
        """
        super().__init__(base_url=f"http://localhost:{router_port}")
        self.pd_config = list(pd_config)
        self.worker_port = worker_port
        self.router_port = router_port
        self.prefill_bootstrap_port = prefill_bootstrap_port
        self.api_key = api_key
        self.health_timeout = health_timeout
        self.health_poll_interval = health_poll_interval
        self._proc: Optional[subprocess.Popen] = None

    def _build_cmd(self) -> list[str]:
        """Return the argv list used to launch the router.

        Raises:
            ValueError: ``pd_config`` contains an unknown role.
        """
        cmd = [
            "python", "-m", "sglang_router.launch_router",
            "--host", "0.0.0.0",
            "--port", str(self.router_port),
            "--prefill-policy", "cache_aware",
            "--decode-policy", "round_robin",
            "--max-concurrent-requests", "128",
            "--rate-limit-tokens-per-second", "0",
            "--queue-size", "0",
            "--health-check-timeout-secs", "600",
            "--log-level", "info",
            "--disable-circuit-breaker",
            "--request-timeout-secs", "3600",
        ]

        if self.api_key is not None:
            cmd.extend(["--api-key", self.api_key])

        if any(role in ("prefill", "decode") for role, _ in self.pd_config):
            cmd.append("--pd-disaggregation")

        worker_urls: list[str] = []
        for role, node_ip in self.pd_config:
            node_url = f"http://{node_ip}:{self.worker_port}"
            if role == "prefill":
                cmd.extend(["--prefill", node_url, str(self.prefill_bootstrap_port)])
            elif role == "decode":
                cmd.extend(["--decode", node_url])
            elif role == "worker":
                worker_urls.append(node_url)
            else:
                raise ValueError(f"invalid pd_config role: {role}")

        if worker_urls:
            # Emit --worker-urls once with every URL. NOTE(review): this
            # assumes sglang_router declares --worker-urls as a multi-value
            # (nargs) option with argparse's default "store" action, under
            # which repeating the flag keeps only the last occurrence —
            # confirm against the installed sglang_router version.
            cmd.extend(["--worker-urls", *worker_urls])

        return cmd

    def start(self):
        """Wait for all nodes, launch the router, and wait for it to be healthy.

        The router is terminated before re-raising if it fails to become
        healthy.

        Raises:
            ValueError: ``pd_config`` contains an unknown role (raised before
                any waiting happens).
            TimeoutError: a node or the router was not healthy in time.
            subprocess.CalledProcessError: the router exited while waiting.
        """
        # Build first so an invalid pd_config fails immediately instead of
        # after up to health_timeout seconds of waiting per node.
        cmd = self._build_cmd()

        for _, node_ip in self.pd_config:
            _wait_ready_url(
                f"http://{node_ip}:{self.worker_port}/health",
                timeout=self.health_timeout,
                poll_interval=self.health_poll_interval,
            )

        print(f"[router] starting: {shlex.join(cmd)}")
        self._proc = subprocess.Popen(cmd)
        try:
            # wait_ready also watches the router process, so a crash on
            # startup fails fast rather than after the full timeout.
            wait_ready(
                self._proc,
                port=self.router_port,
                timeout=self.health_timeout,
                poll_interval=self.health_poll_interval,
            )
        except BaseException:
            self.stop()
            raise

    def stop(self):
        """Terminate the router process (no-op if not running)."""
        terminate_process(self._proc)
        self._proc = None
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# ---------------------------------------------------------------------------
|
|
298
|
+
# Shared helpers
|
|
299
|
+
# ---------------------------------------------------------------------------
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def wait_ready(
    process: subprocess.Popen,
    *,
    port: int,
    timeout: float,
    health_path: str = "/health",
    poll_interval: float = 5.0,
    request_timeout: float = 5.0,
) -> None:
    """Poll a local server's HTTP health endpoint until it answers 2xx.

    Probes ``http://127.0.0.1:<port><health_path>``, checking before each
    probe that *process* is still alive.

    Args:
        process: The server subprocess being waited on.
        port: Local port the server listens on.
        timeout: Overall budget in seconds.
        health_path: Path of the health endpoint.
        poll_interval: Seconds to sleep after each failed probe.
        request_timeout: Per-probe HTTP timeout in seconds.

    Raises:
        subprocess.CalledProcessError: *process* exited before becoming healthy.
        TimeoutError: no 2xx response within *timeout* seconds; the message
            includes the last probe error.
    """
    # monotonic() does not jump with wall-clock adjustments (e.g. NTP), which
    # would otherwise shorten or stretch the deadline.
    deadline = time.monotonic() + timeout
    url = f"http://127.0.0.1:{port}{health_path}"
    last_error = "no response yet"

    while time.monotonic() < deadline:
        # _health_check raises CalledProcessError if the process has died;
        # it propagates straight to the caller.
        error = _health_check(url, request_timeout=request_timeout, process=process)
        if error is None:
            return
        last_error = error
        # Don't oversleep past the deadline on the final interval.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))

    raise TimeoutError(
        f"SGLang health check timed out after {timeout}s for {url}. "
        f"Last error: {last_error}"
    )
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def warmup_chat_completions(
    *,
    port: int,
    payload: Mapping[str, Any],
    headers: Mapping[str, str] | None = None,
    successful_requests: int = 3,
    request_timeout: float = 30.0,
    max_attempts_per_request: int = 2,
    retry_delay: float = 1.0,
) -> None:
    """Warm the OpenAI chat completions endpoint with strict retries.

    Sends *successful_requests* sequential POSTs of *payload* to
    ``http://127.0.0.1:<port>/v1/chat/completions``. Each request may be
    retried up to *max_attempts_per_request* times, sleeping *retry_delay*
    seconds between attempts.

    Args:
        port: Local port of the server.
        payload: JSON body of every request.
        headers: Extra headers merged over ``Content-Type: application/json``.
        successful_requests: Number of requests that must succeed.
        request_timeout: Per-request HTTP timeout in seconds.
        max_attempts_per_request: Attempts allowed per request; must be >= 1.
        retry_delay: Seconds to wait between attempts.

    Raises:
        ValueError: *max_attempts_per_request* is less than 1. Without this
            check such a value would send nothing yet report success.
        RuntimeError: a request failed on every attempt; the HTTP or
            transport error is chained as ``__cause__``.
    """
    if max_attempts_per_request < 1:
        raise ValueError("max_attempts_per_request must be >= 1")

    url = f"http://127.0.0.1:{port}/v1/chat/completions"
    request_headers = {"Content-Type": "application/json"}
    if headers:
        request_headers.update(headers)

    for request_idx in range(successful_requests):
        for attempt in range(max_attempts_per_request):
            try:
                _post_json(url, payload=payload, headers=request_headers, timeout=request_timeout)
                break
            except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError, OSError) as exc:
                if attempt + 1 == max_attempts_per_request:
                    detail = _format_http_error(exc) if isinstance(exc, urllib.error.HTTPError) else f"{type(exc).__name__}: {exc}"
                    raise RuntimeError(
                        f"warmup request {request_idx + 1}/{successful_requests}: {detail}"
                    ) from exc

                time.sleep(retry_delay)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def start_heartbeat_thread(
    health_check_fn: Callable[[], str | None],
    *,
    on_failure: Callable[[], None],
    poll_interval: float = 5.0,
    max_consecutive_failures: int = 3,
    stop_event: threading.Event | None = None,
) -> threading.Thread:
    """Start a daemon thread that periodically calls *health_check_fn*.

    Every *poll_interval* seconds the thread calls *health_check_fn*, which
    should return ``None`` when healthy or an error string otherwise. The
    thread calls *on_failure* once and exits when either:

    * *health_check_fn* raises ``subprocess.CalledProcessError`` (the
      monitored process has exited), or
    * *max_consecutive_failures* checks in a row fail. A failure is an
      error string or any other exception raised by *health_check_fn*.

    Args:
        health_check_fn: Zero-argument health probe.
        on_failure: Callback invoked once on sustained failure.
        poll_interval: Seconds between probes.
        max_consecutive_failures: Failures in a row that trigger *on_failure*.
        stop_event: Optional event; when set, the thread exits at its next
            wake-up without calling *on_failure*.

    Returns:
        The started (daemon) thread.
    """
    def _loop():
        consecutive_failures = 0
        while True:
            if stop_event is None:
                time.sleep(poll_interval)
            elif stop_event.wait(poll_interval):
                return
            try:
                error = health_check_fn()
            except subprocess.CalledProcessError as exc:
                print(f"[heartbeat] server process exited with code {exc.returncode}")
                on_failure()
                return
            except Exception as exc:
                # Count unexpected probe errors as failures. Otherwise the
                # monitor thread would die silently and never call on_failure.
                error = f"{type(exc).__name__}: {exc}"

            if error is None:
                consecutive_failures = 0
                continue

            consecutive_failures += 1
            print(f"[heartbeat] {error} ({consecutive_failures}/{max_consecutive_failures})")
            if consecutive_failures >= max_consecutive_failures:
                print("[heartbeat] sustained health-check failure, invoking on_failure")
                on_failure()
                return

    t = threading.Thread(target=_loop, daemon=True)
    t.start()
    return t
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def terminate_process(
    process: subprocess.Popen | None,
    *,
    terminate_timeout: float = 10.0,
) -> None:
    """Stop *process* gracefully, escalating to a kill if it lingers.

    ``None`` and already-exited processes are ignored. Otherwise the process
    is sent ``terminate()``; if it has not exited within *terminate_timeout*
    seconds it is ``kill()``-ed and reaped.
    """
    if process is not None and process.poll() is None:
        process.terminate()
        try:
            process.wait(timeout=terminate_timeout)
        except subprocess.TimeoutExpired:
            # Graceful shutdown overran its budget; force it and reap.
            process.kill()
            process.wait()
|
|
411
|
+
|
|
412
|
+
# ---------------------------------------------------------------------------
|
|
413
|
+
# SGLang snapshot helpers
|
|
414
|
+
# ---------------------------------------------------------------------------
|
|
415
|
+
def sgl_sleep(port: int, *, timeout: float | None = None) -> None:
    """POST ``{}`` to the local SGLang server's ``/release_memory_occupation``.

    Args:
        port: Local port of the SGLang server.
        timeout: Optional HTTP timeout in seconds. ``None`` (the default)
            applies no timeout; pass a value to avoid blocking indefinitely
            on an unresponsive server.

    Raises:
        urllib.error.HTTPError: the server answered with a non-2xx status.
        urllib.error.URLError, OSError: the request could not be completed.
    """
    _post_json(
        f"http://127.0.0.1:{port}/release_memory_occupation",
        payload={},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def sgl_wake(port: int, *, timeout: float | None = None) -> None:
    """POST ``{}`` to the local SGLang server's ``/resume_memory_occupation``.

    Args:
        port: Local port of the SGLang server.
        timeout: Optional HTTP timeout in seconds. ``None`` (the default)
            applies no timeout; pass a value to avoid blocking indefinitely
            on an unresponsive server.

    Raises:
        urllib.error.HTTPError: the server answered with a non-2xx status.
        urllib.error.URLError, OSError: the request could not be completed.
    """
    _post_json(
        f"http://127.0.0.1:{port}/resume_memory_occupation",
        payload={},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
|
|
429
|
+
|
|
430
|
+
# ---------------------------------------------------------------------------
|
|
431
|
+
# Internal
|
|
432
|
+
# ---------------------------------------------------------------------------
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _health_check(
    url: str,
    *,
    request_timeout: float = 5.0,
    process: subprocess.Popen | None = None,
) -> str | None:
    """Perform one health probe against *url*.

    Returns ``None`` when the server answers 2xx; otherwise returns a short
    description of what went wrong (HTTP status/body or transport error).

    Raises:
        subprocess.CalledProcessError: *process* was given and has exited.
    """
    if process is not None:
        _raise_if_exited(process)

    try:
        code = _get_status(url, timeout=request_timeout)
    except urllib.error.HTTPError as http_err:
        return _format_http_error(http_err)
    except (urllib.error.URLError, TimeoutError, OSError) as err:
        return f"{type(err).__name__}: {err}"

    if not 200 <= code < 300:
        return f"health check returned status {code}"
    return None
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def _raise_if_exited(process: subprocess.Popen) -> None:
    """Raise ``CalledProcessError`` if *process* has terminated; else return."""
    returncode = process.poll()
    if returncode is None:
        return
    raise subprocess.CalledProcessError(returncode, cmd=process.args)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def _wait_ready_url(
    url: str,
    *,
    timeout: float,
    poll_interval: float = 5.0,
    request_timeout: float = 5.0,
) -> None:
    """Poll *url* until it returns 2xx (no subprocess to watch).

    Args:
        url: Health URL to probe.
        timeout: Overall budget in seconds.
        poll_interval: Seconds to sleep after each failed probe.
        request_timeout: Per-probe HTTP timeout in seconds.

    Raises:
        TimeoutError: no 2xx response within *timeout* seconds; the message
            includes the last probe error.
    """
    # monotonic() is immune to wall-clock adjustments that would otherwise
    # shorten or stretch the deadline.
    deadline = time.monotonic() + timeout
    last_error = "no response yet"
    while time.monotonic() < deadline:
        error = _health_check(url, request_timeout=request_timeout)
        if error is None:
            return
        last_error = error
        # Don't oversleep past the deadline on the final interval.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    raise TimeoutError(f"Timed out after {timeout}s waiting for {url}. Last error: {last_error}")
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _get_status(url: str, *, timeout: float) -> int:
    """GET *url* and return the response's HTTP status code.

    ``urlopen`` follows redirects and raises ``urllib.error.HTTPError`` for
    other non-2xx responses, so in practice only 2xx codes are returned.
    """
    req = urllib.request.Request(url, method="GET")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        # ``.status`` supersedes getcode(), deprecated since Python 3.9.
        return resp.status
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _post_json(
    url: str,
    *,
    payload: Mapping[str, Any],
    headers: Mapping[str, str],
    timeout: float | None = None,
) -> int:
    """POST *payload* as UTF-8 JSON to *url* and return the HTTP status code.

    Args:
        url: Target URL.
        payload: JSON-serialisable mapping sent as the request body.
        headers: Request headers (callers pass ``Content-Type`` themselves).
        timeout: Socket timeout in seconds; ``None`` means no timeout.

    Raises:
        urllib.error.HTTPError: non-2xx response (raised by ``urlopen``).
        urllib.error.URLError, OSError: the request could not be completed.
    """
    body = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(url, data=body, headers=dict(headers), method="POST")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        # ``.status`` supersedes getcode(), deprecated since Python 3.9.
        return resp.status
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def _format_http_error(exc: urllib.error.HTTPError, *, max_body_chars: int = 500) -> str:
    """Render *exc* as ``"status <code>"`` plus its (truncated) response body.

    Args:
        exc: The HTTP error to describe.
        max_body_chars: Maximum number of body characters to include.

    Returns:
        ``"status <code>: <body>"`` when a non-blank body can be read,
        otherwise ``"status <code>"``.
    """
    try:
        raw = exc.read() or b""
    except Exception:
        # Best-effort diagnostic: the body stream may be closed, missing or
        # truncated. Never let that mask the HTTP error being reported.
        raw = b""
    body = raw.decode("utf-8", errors="replace").strip()
    if body:
        return f"status {exc.code}: {body[:max_body_chars]}"
    return f"status {exc.code}"
|