chimeraforge 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chimeraforge/__init__.py +8 -0
- chimeraforge/bench/__init__.py +43 -0
- chimeraforge/bench/backends/__init__.py +33 -0
- chimeraforge/bench/backends/base.py +56 -0
- chimeraforge/bench/backends/ollama.py +117 -0
- chimeraforge/bench/backends/tgi.py +149 -0
- chimeraforge/bench/backends/vllm.py +123 -0
- chimeraforge/bench/metrics.py +192 -0
- chimeraforge/bench/profiles.py +55 -0
- chimeraforge/bench/prompts.py +32 -0
- chimeraforge/bench/runner.py +341 -0
- chimeraforge/cli.py +852 -0
- chimeraforge/compare/__init__.py +23 -0
- chimeraforge/compare/comparator.py +281 -0
- chimeraforge/eval/__init__.py +56 -0
- chimeraforge/eval/metrics.py +334 -0
- chimeraforge/eval/runner.py +217 -0
- chimeraforge/eval/tasks.py +190 -0
- chimeraforge/planner/__init__.py +29 -0
- chimeraforge/planner/constants.py +45 -0
- chimeraforge/planner/data/fitted_models.json +138 -0
- chimeraforge/planner/engine.py +181 -0
- chimeraforge/planner/formatter.py +161 -0
- chimeraforge/planner/hardware.py +68 -0
- chimeraforge/planner/models.py +407 -0
- chimeraforge/refit/__init__.py +41 -0
- chimeraforge/refit/fitter.py +480 -0
- chimeraforge/refit/validator.py +278 -0
- chimeraforge/report/__init__.py +54 -0
- chimeraforge/report/analysis.py +244 -0
- chimeraforge/report/generator.py +753 -0
- chimeraforge-0.2.0.dist-info/METADATA +726 -0
- chimeraforge-0.2.0.dist-info/RECORD +37 -0
- chimeraforge-0.2.0.dist-info/WHEEL +5 -0
- chimeraforge-0.2.0.dist-info/entry_points.txt +2 -0
- chimeraforge-0.2.0.dist-info/licenses/LICENSE +21 -0
- chimeraforge-0.2.0.dist-info/top_level.txt +1 -0
chimeraforge/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""ChimeraForge Benchmarking Engine — run real LLM inference benchmarks.
|
|
2
|
+
|
|
3
|
+
Public API:
|
|
4
|
+
run_benchmark Run a single benchmark configuration
|
|
5
|
+
run_quant_sweep Sweep across quantization levels
|
|
6
|
+
run_context_sweep Sweep across context lengths
|
|
7
|
+
save_results Persist results as JSON
|
|
8
|
+
BenchmarkResult Top-level result container
|
|
9
|
+
RunMetrics Per-run timing metrics
|
|
10
|
+
AggregateMetrics Statistical summary
|
|
11
|
+
get_backend Instantiate a backend adapter
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from chimeraforge.bench.backends import get_backend
|
|
15
|
+
from chimeraforge.bench.metrics import (
|
|
16
|
+
AggregateMetrics,
|
|
17
|
+
BenchmarkResult,
|
|
18
|
+
EnvironmentInfo,
|
|
19
|
+
RunMetrics,
|
|
20
|
+
StatSummary,
|
|
21
|
+
)
|
|
22
|
+
from chimeraforge.bench.profiles import WorkloadProfile, get_profile
|
|
23
|
+
from chimeraforge.bench.runner import (
|
|
24
|
+
run_benchmark,
|
|
25
|
+
run_context_sweep,
|
|
26
|
+
run_quant_sweep,
|
|
27
|
+
save_results,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Names re-exported by this package, kept in sorted order
# (classes first, then functions).
__all__ = [
    # result / metric / profile types
    "AggregateMetrics", "BenchmarkResult", "EnvironmentInfo",
    "RunMetrics", "StatSummary", "WorkloadProfile",
    # entry-point functions
    "get_backend", "get_profile",
    "run_benchmark", "run_context_sweep", "run_quant_sweep", "save_results",
]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Backend registry — maps backend names to adapter classes."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from chimeraforge.bench.backends.base import Backend
|
|
6
|
+
from chimeraforge.bench.backends.ollama import OllamaBackend
|
|
7
|
+
from chimeraforge.bench.backends.tgi import TGIBackend
|
|
8
|
+
from chimeraforge.bench.backends.vllm import VLLMBackend
|
|
9
|
+
|
|
10
|
+
# Backend identifier (the ``name`` accepted by get_backend) -> adapter class.
BACKEND_REGISTRY: dict[str, type[Backend]] = dict(
    ollama=OllamaBackend,
    vllm=VLLMBackend,
    tgi=TGIBackend,
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_backend(name: str, **kwargs: object) -> Backend:
    """Build a backend adapter from its registry name.

    Args:
        name: A key of BACKEND_REGISTRY ("ollama", "vllm", or "tgi").
        **kwargs: Forwarded unchanged to the adapter's constructor
            (for example ``base_url``).

    Returns:
        A newly constructed instance of the matching Backend subclass.

    Raises:
        ValueError: If *name* is not a registered backend.
    """
    if name not in BACKEND_REGISTRY:
        raise ValueError(f"Unknown backend: {name}. Available: {list(BACKEND_REGISTRY)}")
    backend_cls = BACKEND_REGISTRY[name]
    return backend_cls(**kwargs)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Abstract backend interface for LLM serving backends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
from chimeraforge.bench.metrics import RunMetrics
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Backend(ABC):
    """Abstract interface for LLM serving backends.

    Each backend adapter translates the common generate() call into
    the backend-specific HTTP API, extracts timing metrics from the
    response, and returns a standardized RunMetrics object.

    Subclasses set the ``name`` class attribute and implement the four
    abstract coroutines below. ``close()`` is also part of the interface,
    with a no-op default, so code holding any ``Backend`` can release
    resources without knowing the concrete adapter.
    """

    # Registry identifier of the backend; each concrete subclass sets it.
    name: str

    @abstractmethod
    async def health_check(self) -> tuple[bool, str]:
        """Check if the backend is reachable.

        Returns:
            Tuple of (ok, message). If not ok, message describes the error.
        """

    @abstractmethod
    async def check_model(self, model: str) -> tuple[bool, str]:
        """Check if a model is available on the backend.

        Returns:
            Tuple of (ok, message). If not ok, message includes remediation.
        """

    @abstractmethod
    async def generate(
        self,
        model: str,
        prompt: str,
        options: dict | None = None,
    ) -> RunMetrics:
        """Run a single generation and return metrics.

        Args:
            model: Model name or tag.
            prompt: Input prompt text.
            options: Backend-specific generation options.

        Returns:
            RunMetrics with timing and token counts.
        """

    @abstractmethod
    async def get_version(self) -> str | None:
        """Return the backend version string, or None if unavailable."""

    async def close(self) -> None:
        """Release any resources held by the backend.

        The base implementation does nothing, so calling it is always safe.
        Subclasses that keep open connections should override it.
        """
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Ollama backend adapter.
|
|
2
|
+
|
|
3
|
+
Implements the Backend interface against the Ollama REST API.
|
|
4
|
+
Uses stream=False to extract eval_count / eval_duration / prompt_eval_duration
|
|
5
|
+
from the final JSON response, matching the banterhearts measurement pattern.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from chimeraforge.bench.backends.base import Backend
|
|
13
|
+
from chimeraforge.bench.metrics import RunMetrics
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OllamaBackend(Backend):
    """Ollama serving backend (http://localhost:11434 by default).

    A single ``httpx.AsyncClient`` is created lazily and reused across
    calls; call :meth:`close` to release it.
    """

    name = "ollama"

    def __init__(self, base_url: str = "http://localhost:11434") -> None:
        """Create an adapter for the Ollama server at *base_url*.

        Args:
            base_url: Server root URL; a trailing "/" is stripped so endpoint
                paths can be appended directly.
        """
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating a new one if none is open."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=300)
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    async def health_check(self) -> tuple[bool, str]:
        """GET / -- Ollama returns 'Ollama is running'.

        Returns:
            ``(True, "Ollama is running")`` on HTTP 200, else ``(False, reason)``.
            Connection failures and timeouts are reported, not raised.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/", timeout=10)
        except httpx.ConnectError:
            return False, f"Ollama not running at {self.base_url}"
        except httpx.TimeoutException:
            return False, f"Ollama timed out at {self.base_url}"
        if resp.status_code == 200:
            return True, "Ollama is running"
        return False, f"Ollama returned status {resp.status_code}"

    async def check_model(self, model: str) -> tuple[bool, str]:
        """POST /api/show to verify model availability.

        Returns:
            ``(True, "")`` on HTTP 200; otherwise ``(False, msg)`` with a
            remediation hint. Like health_check(), connection failures and
            timeouts are reported instead of raised, so a slow server cannot
            crash the preflight check.
        """
        try:
            client = await self._get_client()
            resp = await client.post(
                f"{self.base_url}/api/show",
                json={"name": model},
                timeout=30,
            )
        except httpx.ConnectError:
            return False, f"Ollama not running at {self.base_url}"
        except httpx.TimeoutException:
            return False, f"Ollama timed out at {self.base_url}"
        if resp.status_code == 200:
            return True, ""
        return False, f"Model not found. Run: ollama pull {model}"

    async def generate(
        self,
        model: str,
        prompt: str,
        options: dict | None = None,
    ) -> RunMetrics:
        """POST /api/generate with stream=False, extract timing metrics.

        ``options`` is forwarded verbatim as the Ollama "options" object when
        it is non-empty.

        All timings come from the server's response, not from a client
        clock. The ``*_duration`` fields are treated as nanoseconds.
        Throughput is ``eval_count`` divided by ``eval_duration``, which
        covers decode time only. TTFT is approximated by
        ``prompt_eval_duration``.

        Raises:
            httpx.HTTPStatusError: If Ollama answers with an error status.
        """
        payload: dict = {
            "model": model,
            "prompt": prompt,
            "stream": False,
        }
        if options:
            payload["options"] = options

        client = await self._get_client()
        resp = await client.post(
            f"{self.base_url}/api/generate",
            json=payload,
            timeout=300,
        )
        resp.raise_for_status()
        data = resp.json()

        eval_count = data.get("eval_count", 0)
        eval_duration_ns = data.get("eval_duration", 0)
        prompt_eval_duration_ns = data.get("prompt_eval_duration", 0)
        total_duration_ns = data.get("total_duration", 0)

        eval_duration_ms = eval_duration_ns / 1e6
        prompt_eval_duration_ms = prompt_eval_duration_ns / 1e6
        total_duration_ms = total_duration_ns / 1e6

        # Guard against a zero/missing eval_duration (e.g. no tokens produced).
        throughput = eval_count / (eval_duration_ns / 1e9) if eval_duration_ns > 0 else 0.0
        ttft = prompt_eval_duration_ms

        return RunMetrics(
            tokens_generated=eval_count,
            throughput_tps=throughput,
            ttft_ms=ttft,
            total_duration_ms=total_duration_ms,
            prompt_eval_duration_ms=prompt_eval_duration_ms,
            eval_duration_ms=eval_duration_ms,
        )

    async def get_version(self) -> str | None:
        """GET /api/version.

        Best-effort: returns the response's "version" field, or None on a
        non-200 status, a missing field, or any exception.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/api/version", timeout=10)
            if resp.status_code == 200:
                return resp.json().get("version")
        except Exception:
            # Version is informational only; never fail a benchmark over it.
            pass
        return None
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""TGI (Text Generation Inference) backend adapter.
|
|
2
|
+
|
|
3
|
+
Implements the Backend interface against the HuggingFace TGI HTTP API.
|
|
4
|
+
TGI exposes /generate with detailed timing in the response.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from chimeraforge.bench.backends.base import Backend
|
|
14
|
+
from chimeraforge.bench.metrics import RunMetrics
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TGIBackend(Backend):
    """HuggingFace TGI serving backend (http://localhost:8080 by default).

    A single ``httpx.AsyncClient`` is created lazily and reused across
    calls; call :meth:`close` to release it.
    """

    name = "tgi"

    def __init__(self, base_url: str = "http://localhost:8080") -> None:
        """Create an adapter for the TGI server at *base_url*.

        Args:
            base_url: Server root URL; a trailing "/" is stripped so endpoint
                paths can be appended directly.
        """
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating a new one if none is open."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=300)
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    async def health_check(self) -> tuple[bool, str]:
        """GET /health to check TGI availability.

        Returns:
            ``(True, "TGI is running")`` on HTTP 200, else ``(False, reason)``.
            Connection failures and timeouts are reported, not raised.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/health", timeout=10)
        except httpx.ConnectError:
            return False, f"TGI not running at {self.base_url}"
        except httpx.TimeoutException:
            return False, f"TGI timed out at {self.base_url}"
        if resp.status_code == 200:
            return True, "TGI is running"
        return False, f"TGI returned status {resp.status_code}"

    async def check_model(self, model: str) -> tuple[bool, str]:
        """GET /info to verify the requested model is the one TGI loaded.

        TGI loads a single model at startup. *model* is accepted when it
        meets any of these conditions:

        * It equals the loaded ``model_id`` exactly.
        * It equals the loaded ``model_id`` case-insensitively.
        * It equals the final "/"-separated component of the loaded
          ``model_id``, case-insensitively. For example,
          "Llama-3.2-3B-Instruct" matches "meta-llama/Llama-3.2-3B-Instruct".

        Other components, such as the organization ("meta-llama"), do not
        match.

        Returns:
            ``(True, "")`` on a match, otherwise ``(False, msg)``. Connection
            failures and timeouts are reported rather than raised.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/info", timeout=30)
        except httpx.ConnectError:
            return False, f"TGI not running at {self.base_url}"
        except httpx.TimeoutException:
            return False, f"TGI timed out at {self.base_url}"
        if resp.status_code != 200:
            return False, f"Cannot get model info (status {resp.status_code})"

        loaded = resp.json().get("model_id", "")
        if model == loaded:
            return True, ""
        # Compare against the whole ID and its last path component only; a
        # trailing "/" (e.g. a local directory path) is ignored.
        loaded_norm = loaded.rstrip("/").lower()
        wanted = model.lower()
        if wanted in (loaded_norm, loaded_norm.rsplit("/", 1)[-1]):
            return True, ""
        return False, (
            f"TGI has '{loaded}' loaded, not '{model}'. Restart TGI with the desired model."
        )

    async def generate(
        self,
        model: str,
        prompt: str,
        options: dict | None = None,
    ) -> RunMetrics:
        """POST /generate, extract token counts and timing.

        *model* is not sent, because TGI serves whichever model it was
        started with (see check_model()).

        Recognized ``options`` keys are ``max_new_tokens`` (default 256) and
        ``temperature`` (default 0.7). Other keys are ignored.

        Client wall-clock time around the request is used for
        ``total_duration_ms`` and for throughput. TTFT is reported only when
        the response carries a prefill timing field; otherwise it is -1.0.

        Raises:
            httpx.HTTPStatusError: If TGI answers with an error status.
        """
        opts = options or {}
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": opts.get("max_new_tokens", 256),
                "temperature": opts.get("temperature", 0.7),
                # TGI includes the "details" object (which carries
                # generated_tokens) only when asked. Without it, every run
                # would report 0 tokens and 0 tok/s.
                "details": True,
            },
        }

        client = await self._get_client()
        t0 = time.perf_counter()
        resp = await client.post(
            f"{self.base_url}/generate",
            json=payload,
            timeout=300,
        )
        total_s = time.perf_counter() - t0
        resp.raise_for_status()
        data = resp.json()

        # `or {}` also covers an explicit null "details" value.
        details = data.get("details") or {}
        generated_tokens = details.get("generated_tokens", 0)
        total_duration_ms = total_s * 1000

        # NOTE(review): these timing field names vary by TGI version and may
        # be absent entirely; confirm against the deployed server.
        prefill_time = (
            details.get("prefill_time")  # TGI 2.x (seconds)
            or details.get("prefill_duration_ns", 0) / 1e9  # hypothetical ns
        )
        decode_time = (
            details.get("decode_time")  # TGI 2.x (seconds)
            or details.get("decode_duration_ns", 0) / 1e9
        )

        if prefill_time and prefill_time > 0:
            prompt_eval_duration_ms = prefill_time * 1000
            eval_duration_ms = decode_time * 1000 if decode_time else total_duration_ms
        else:
            # Timing not available from TGI response
            prompt_eval_duration_ms = 0.0
            eval_duration_ms = total_duration_ms

        throughput = generated_tokens / total_s if total_s > 0 else 0.0
        ttft = prompt_eval_duration_ms if prompt_eval_duration_ms > 0 else -1.0

        return RunMetrics(
            tokens_generated=generated_tokens,
            throughput_tps=throughput,
            ttft_ms=ttft,
            total_duration_ms=total_duration_ms,
            prompt_eval_duration_ms=prompt_eval_duration_ms,
            eval_duration_ms=eval_duration_ms,
        )

    async def get_version(self) -> str | None:
        """GET /info and extract version.

        Best-effort: returns the "version" field, or None on a non-200
        status, a missing field, or any exception.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/info", timeout=10)
            if resp.status_code == 200:
                data = resp.json()
                return data.get("version")
        except Exception:
            # Version is informational only; never fail a benchmark over it.
            pass
        return None
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""vLLM backend adapter.
|
|
2
|
+
|
|
3
|
+
Implements the Backend interface against the vLLM OpenAI-compatible API.
|
|
4
|
+
vLLM exposes /v1/completions with usage stats in the response.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from chimeraforge.bench.backends.base import Backend
|
|
14
|
+
from chimeraforge.bench.metrics import RunMetrics
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class VLLMBackend(Backend):
    """vLLM serving backend (OpenAI-compatible, http://localhost:8000 by default).

    A single ``httpx.AsyncClient`` is created lazily and reused across
    calls; call :meth:`close` to release it.
    """

    name = "vllm"

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        """Create an adapter for the vLLM server at *base_url*.

        Args:
            base_url: Server root URL; a trailing "/" is stripped so endpoint
                paths can be appended directly.
        """
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating a new one if none is open."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=300)
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    async def health_check(self) -> tuple[bool, str]:
        """GET /health, falling back to /v1/models, to check availability.

        Returns:
            ``(True, "vLLM is running")`` if either endpoint answers 200.
            Otherwise returns ``(False, reason)``, where the reported status
            is the one from /v1/models. Connection failures and timeouts
            are reported, not raised.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/health", timeout=10)
            if resp.status_code == 200:
                return True, "vLLM is running"
            # Fallback: try /v1/models
            resp = await client.get(f"{self.base_url}/v1/models", timeout=10)
        except httpx.ConnectError:
            return False, f"vLLM not running at {self.base_url}"
        except httpx.TimeoutException:
            return False, f"vLLM timed out at {self.base_url}"
        if resp.status_code == 200:
            return True, "vLLM is running"
        return False, f"vLLM returned status {resp.status_code}"

    async def check_model(self, model: str) -> tuple[bool, str]:
        """GET /v1/models and check if model is listed.

        Returns:
            ``(True, "")`` when *model* exactly equals a served model id.
            Otherwise returns ``(False, msg)``, listing the available ids.
            Like health_check(), connection failures and timeouts are
            reported instead of raised.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/v1/models", timeout=30)
        except httpx.ConnectError:
            return False, f"vLLM not running at {self.base_url}"
        except httpx.TimeoutException:
            return False, f"vLLM timed out at {self.base_url}"
        if resp.status_code != 200:
            return False, f"Cannot list models (status {resp.status_code})"
        model_ids = [m["id"] for m in resp.json().get("data", [])]
        if model in model_ids:
            return True, ""
        return False, (
            f"Model '{model}' not found. Available: {', '.join(model_ids) or 'none'}"
        )

    async def generate(
        self,
        model: str,
        prompt: str,
        options: dict | None = None,
    ) -> RunMetrics:
        """POST /v1/completions, extract usage and timing.

        Recognized ``options`` keys are ``max_tokens`` (default 256) and
        ``temperature`` (default 0.7). Other keys are ignored.

        Token count comes from ``usage.completion_tokens``. Timing is client
        wall-clock around the request. TTFT is reported as -1.0 because a
        non-streaming response does not expose it.

        Raises:
            httpx.HTTPStatusError: If vLLM answers with an error status.
        """
        opts = options or {}
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": opts.get("max_tokens", 256),
            "temperature": opts.get("temperature", 0.7),
        }

        client = await self._get_client()
        t0 = time.perf_counter()
        resp = await client.post(
            f"{self.base_url}/v1/completions",
            json=payload,
            timeout=300,
        )
        total_s = time.perf_counter() - t0
        resp.raise_for_status()
        data = resp.json()

        usage = data.get("usage", {})
        completion_tokens = usage.get("completion_tokens", 0)
        total_duration_ms = total_s * 1000

        # vLLM non-streaming API doesn't expose TTFT; throughput from wall clock
        eval_duration_ms = total_duration_ms
        throughput = completion_tokens / total_s if total_s > 0 else 0.0

        return RunMetrics(
            tokens_generated=completion_tokens,
            throughput_tps=throughput,
            ttft_ms=-1.0,  # Not measurable without streaming
            total_duration_ms=total_duration_ms,
            prompt_eval_duration_ms=0.0,
            eval_duration_ms=eval_duration_ms,
        )

    async def get_version(self) -> str | None:
        """GET /version and return its "version" field.

        If the JSON body has no "version" key, the stringified body is
        returned instead. This is best-effort: a non-200 status or any
        exception yields None.
        """
        try:
            client = await self._get_client()
            resp = await client.get(f"{self.base_url}/version", timeout=10)
            if resp.status_code == 200:
                data = resp.json()
                return data.get("version", str(data))
        except Exception:
            # Version is informational only; never fail a benchmark over it.
            pass
        return None
|