vllm_judge-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vllm_judge/cli.py ADDED
@@ -0,0 +1,287 @@
+ """
+ Command-line interface for vLLM Judge.
+ """
+ import asyncio
+ import json
+ import sys
+ from typing import Optional
+ import click
+
+ from vllm_judge import Judge
+ from vllm_judge.api.server import start_server as start_api_server
+ from vllm_judge.api.client import JudgeClient
+ from vllm_judge.metrics import BUILTIN_METRICS
+
+
+ @click.group()
+ def cli():
+     """vLLM Judge - LLM-as-a-Judge evaluation tool."""
+     pass
+
+
+ @cli.command()
+ @click.option('--base-url', required=True, help='vLLM server URL')
+ @click.option('--model', help='Model name/path (auto-detected if not provided)')
+ @click.option('--host', default='0.0.0.0', help='API server host')
+ @click.option('--port', default=8080, help='API server port')
+ @click.option('--reload', is_flag=True, help='Enable auto-reload for development')
+ @click.option('--max-concurrent', default=50, help='Maximum concurrent requests')
+ @click.option('--timeout', default=30.0, help='Request timeout in seconds')
+ def serve(base_url: str, model: str, host: str, port: int, reload: bool, max_concurrent: int, timeout: float):
+     """Start the Judge API server."""
+     click.echo("Starting vLLM Judge API server...")
+     click.echo(f"Base URL: {base_url}")
+     click.echo(f"Model: {model or 'auto-detect'}")
+     click.echo(f"Server: http://{host}:{port}")
+
+     start_api_server(
+         base_url=base_url,
+         model=model,
+         host=host,
+         port=port,
+         reload=reload,
+         max_concurrent=max_concurrent,
+         timeout=timeout
+     )
+
+
+ @cli.command()
+ @click.option('--api-url', help='Judge API URL (if using remote server)')
+ @click.option('--base-url', help='vLLM server URL (if using local)')
+ @click.option('--model', help='Model name (if using local)')
+ @click.option('--response', required=True, help='Text to evaluate')
+ @click.option('--criteria', help='Evaluation criteria')
+ @click.option('--metric', help='Pre-defined metric name')
+ @click.option('--scale', nargs=2, type=int, help='Numeric scale (min max)')
+ @click.option('--rubric', help='Evaluation rubric')
+ @click.option('--context', help='Additional context')
+ @click.option('--output', type=click.Choice(['json', 'text']), default='text', help='Output format')
+ def evaluate(
+     api_url: Optional[str],
+     base_url: Optional[str],
+     model: Optional[str],
+     response: str,
+     criteria: Optional[str],
+     metric: Optional[str],
+     scale: Optional[tuple],
+     rubric: Optional[str],
+     context: Optional[str],
+     output: str
+ ):
+     """Evaluate a single response."""
+     async def run_evaluation():
+         if api_url:
+             # Use API client
+             async with JudgeClient(api_url) as client:
+                 result = await client.evaluate(
+                     response=response,
+                     criteria=criteria,
+                     metric=metric,
+                     scale=scale,
+                     rubric=rubric,
+                     context=context
+                 )
+         else:
+             # Use local Judge
+             if not base_url:
+                 click.echo("Error: Either --api-url or --base-url is required", err=True)
+                 sys.exit(1)
+
+             judge = Judge.from_url(base_url, model=model)
+             async with judge:
+                 result = await judge.evaluate(
+                     response=response,
+                     criteria=criteria,
+                     metric=metric,
+                     scale=scale,
+                     rubric=rubric,
+                     context=context
+                 )
+
+         # Format output
+         if output == 'json':
+             click.echo(json.dumps(result.model_dump(), indent=2))
+         else:
+             click.echo(f"Decision: {result.decision}")
+             if result.score is not None:
+                 click.echo(f"Score: {result.score}")
+             click.echo(f"Reasoning: {result.reasoning}")
+
+     asyncio.run(run_evaluation())
+
+
+ @cli.command()
+ @click.option('--api-url', help='Judge API URL (if using remote server)')
+ @click.option('--base-url', help='vLLM server URL (if using local)')
+ @click.option('--model', help='Model name (if using local)')
+ @click.option('--response-a', required=True, help='First response')
+ @click.option('--response-b', required=True, help='Second response')
+ @click.option('--criteria', required=True, help='Comparison criteria')
+ @click.option('--output', type=click.Choice(['json', 'text']), default='text', help='Output format')
+ def compare(
+     api_url: Optional[str],
+     base_url: Optional[str],
+     model: Optional[str],
+     response_a: str,
+     response_b: str,
+     criteria: str,
+     output: str
+ ):
+     """Compare two responses."""
+     async def run_comparison():
+         if api_url:
+             async with JudgeClient(api_url) as client:
+                 result = await client.compare(
+                     response_a=response_a,
+                     response_b=response_b,
+                     criteria=criteria
+                 )
+         else:
+             if not base_url:
+                 click.echo("Error: Either --api-url or --base-url is required", err=True)
+                 sys.exit(1)
+
+             judge = Judge.from_url(base_url, model=model)
+             async with judge:
+                 result = await judge.compare(
+                     response_a=response_a,
+                     response_b=response_b,
+                     criteria=criteria
+                 )
+
+         if output == 'json':
+             click.echo(json.dumps(result.model_dump(), indent=2))
+         else:
+             click.echo(f"Winner: {result.decision}")
+             click.echo(f"Reasoning: {result.reasoning}")
+
+     asyncio.run(run_comparison())
+
+
+ @cli.command()
+ @click.option('--api-url', required=True, help='Judge API URL')
+ def health(api_url: str):
+     """Check API health status."""
+     async def check_health():
+         async with JudgeClient(api_url) as client:
+             try:
+                 health_data = await client.health_check()
+                 click.echo(json.dumps(health_data, indent=2))
+             except Exception as e:
+                 click.echo(f"Health check failed: {e}", err=True)
+                 sys.exit(1)
+
+     asyncio.run(check_health())
+
+
+ @cli.command()
+ @click.option('--api-url', help='Judge API URL (if using remote server)')
+ @click.option('--filter', help='Filter metrics by name')
+ def list_metrics(api_url: Optional[str], filter: Optional[str]):
+     """List available metrics."""
+     async def list_all_metrics():
+         if api_url:
+             async with JudgeClient(api_url) as client:
+                 metrics = await client.list_metrics()
+                 for metric in metrics:
+                     if filter and filter.lower() not in metric.name.lower():
+                         continue
+                     click.echo(f"\n{metric.name}:")
+                     click.echo(f" Criteria: {metric.criteria}")
+                     if metric.has_scale:
+                         click.echo(f" Scale: {metric.scale}")
+                     click.echo(f" Has rubric: {metric.has_rubric}")
+                     click.echo(f" Examples: {metric.example_count}")
+         else:
+             # List built-in metrics
+             for name, metric in BUILTIN_METRICS.items():
+                 if filter and filter.lower() not in name.lower():
+                     continue
+                 click.echo(f"\n{name}:")
+                 click.echo(f" Criteria: {metric.criteria}")
+                 if metric.scale:
+                     click.echo(f" Scale: {metric.scale}")
+                 click.echo(f" Has rubric: {'Yes' if metric.rubric else 'No'}")
+                 click.echo(f" Examples: {len(metric.examples)}")
+
+     asyncio.run(list_all_metrics())
+
+
+ @cli.command()
+ @click.option('--api-url', required=True, help='Judge API URL')
+ @click.option('--file', required=True, type=click.File('r'), help='JSON file with batch data')
+ @click.option('--async', 'use_async', is_flag=True, help='Use async batch processing')
+ @click.option('--max-concurrent', type=int, help='Maximum concurrent requests')
+ @click.option('--output', type=click.File('w'), help='Output file (default: stdout)')
+ def batch(api_url: str, file, use_async: bool, max_concurrent: Optional[int], output):
+     """Run batch evaluation from JSON file."""
+     # Load batch data
+     try:
+         data = json.load(file)
+         if not isinstance(data, list):
+             click.echo("Error: Batch file must contain a JSON array", err=True)
+             sys.exit(1)
+     except json.JSONDecodeError as e:
+         click.echo(f"Error parsing JSON: {e}", err=True)
+         sys.exit(1)
+
+     async def run_batch():
+         async with JudgeClient(api_url) as client:
+             if use_async:
+                 click.echo(f"Starting async batch evaluation of {len(data)} items...")
+                 result = await client.async_batch_evaluate(
+                     data=data,
+                     max_concurrent=max_concurrent
+                 )
+             else:
+                 click.echo(f"Running batch evaluation of {len(data)} items...")
+                 result = await client.batch_evaluate(
+                     data=data,
+                     max_concurrent=max_concurrent
+                 )
+
+         # Format results
+         output_data = {
+             "total": result.total,
+             "successful": result.successful,
+             "failed": result.failed,
+             "success_rate": result.success_rate,
+             "duration_seconds": result.duration_seconds,
+             "results": []
+         }
+
+         for r in result.results:
+             if isinstance(r, Exception):
+                 output_data["results"].append({"error": str(r)})
+             else:
+                 output_data["results"].append({
+                     "decision": r.decision,
+                     "reasoning": r.reasoning,
+                     "score": r.score,
+                     "metadata": r.metadata
+                 })
+
+         # Write output
+         output_file = output or sys.stdout
+         json.dump(output_data, output_file, indent=2)
+         if output:
+             click.echo(f"Results written to {output.name}")
+
+         # Summary
+ click.echo(f"\nSummary:")
273
+ click.echo(f" Total: {result.total}")
274
+ click.echo(f" Successful: {result.successful}")
275
+ click.echo(f" Failed: {result.failed}")
276
+ click.echo(f" Success rate: {result.success_rate:.1%}")
277
+ click.echo(f" Duration: {result.duration_seconds:.1f}s")
278
+
279
+ asyncio.run(run_batch())
280
+
281
+
282
+ def main():
283
+ """Main entry point."""
284
+ cli()
285
+
286
+
287
+ if __name__ == '__main__':
288
+ main()
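
A quick way to smoke-test these commands without a live vLLM server is click's bundled test runner. The sketch below is illustrative rather than part of the package; it relies only on what the diff shows, namely that click derives the command name list-metrics from list_metrics, and that the command falls back to BUILTIN_METRICS when --api-url is omitted.

from click.testing import CliRunner

from vllm_judge.cli import cli


def test_list_metrics_offline():
    # list-metrics needs no server: without --api-url it iterates
    # BUILTIN_METRICS locally (assumed non-empty in the package).
    runner = CliRunner()
    result = runner.invoke(cli, ["list-metrics"])
    assert result.exit_code == 0
    assert "Criteria:" in result.output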
vllm_judge/client.py ADDED
@@ -0,0 +1,259 @@
+ from typing import List, Dict, Any
+ import httpx
+ from tenacity import (
+     retry,
+     stop_after_attempt,
+     wait_exponential,
+     retry_if_exception_type,
+ )
+
+ from vllm_judge.models import JudgeConfig
+ from vllm_judge.exceptions import (
+     ConnectionError,
+     TimeoutError,
+     ParseError,
+     RetryExhaustedError
+ )
+
+ CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
+ COMPLETIONS_ENDPOINT = "/v1/completions"
+ MODELS_ENDPOINT = "/v1/models"
+
+ class VLLMClient:
+     """Async client for vLLM endpoints."""
+
+     def __init__(self, config: JudgeConfig):
+         """
+         Initialize vLLM client.
+
+         Args:
+             config: Judge configuration
+         """
+
+         if not config.model:
+             config.model = detect_model_sync(config.base_url)
+         self.config = config
+         self.session = httpx.AsyncClient(
+             base_url=config.base_url,
+             timeout=httpx.Timeout(config.timeout),
+             limits=httpx.Limits(
+                 max_connections=100,
+                 max_keepalive_connections=20
+             ),
+             headers={
+                 "Authorization": f"Bearer {config.api_key}",
+                 "Content-Type": "application/json"
+             }
+         )
+
+     async def __aenter__(self):
+         """Async context manager entry."""
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit."""
+         await self.close()
+
+     async def close(self):
+         """Close the HTTP session."""
+         await self.session.aclose()
+
+     def _log_retry(self, retry_state):
+         """Log retry attempts."""
+         attempt = retry_state.attempt_number
+         if attempt > 1:
+             print(f"Retry attempt {attempt} after error: {retry_state.outcome.exception()}")
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=1, max=10),
+         retry=retry_if_exception_type((httpx.HTTPError, ConnectionError, TimeoutError))
+     )
+     async def _request_with_retry(self, endpoint: str, method: str = "POST", **kwargs) -> Dict[str, Any]:
+         """
+         Make HTTP request with retry logic.
+
+         Args:
+             endpoint: API endpoint
+             method: HTTP method (defaults to POST)
+             **kwargs: Request parameters
+
+         Returns:
+             Parsed JSON response
+
+         Raises:
+             ConnectionError: If unable to connect
+             TimeoutError: If request times out
+             RetryExhaustedError: If all retries fail
+         """
+         try:
+             response = await self.session.request(method, endpoint, **kwargs)
+             response.raise_for_status()
+             return response.json()
+         except httpx.ConnectError as e:
+             raise ConnectionError(f"Failed to connect to {self.config.base_url}: {e}")
+         except httpx.TimeoutException as e:
+             raise TimeoutError(f"Request timed out after {self.config.timeout}s: {e}")
+         except httpx.HTTPStatusError as e:
+             # Parse error message from response if available
+             try:
+                 error_detail = e.response.json().get('detail', str(e))
+             except Exception:
+                 error_detail = str(e)
+             raise ConnectionError(f"HTTP {e.response.status_code}: {error_detail}")
+         except Exception as e:
+             raise ConnectionError(f"Unexpected error: {e}")
+
+     async def chat_completion(self, messages: List[Dict[str, str]]) -> str:
+         """
+         Use chat completions endpoint (handles templates automatically).
+
+         Args:
+             messages: List of chat messages
+
+         Returns:
+             Model response content
+
+         Raises:
+             ConnectionError: If request fails
+             ParseError: If response parsing fails
+         """
+         request_data = {
+             "model": self.config.model,
+             "messages": messages,
+             "temperature": self.config.temperature,
+             "max_tokens": self.config.max_tokens,
+             # "top_p": self.config.top_p,
+         }
+
+         # # Request JSON response format if supported
+         # if self.config.temperature < 0.2: # Only for low temperature
+         #     request_data["response_format"] = {"type": "json_object"}
+
+         try:
+             response = await self._request_with_retry(
+                 CHAT_COMPLETIONS_ENDPOINT,
+                 json=request_data
+             )
+
+             # Extract content from response
+             if "choices" not in response or not response["choices"]:
+                 raise ParseError("Invalid response format: missing choices")
+
+             content = response["choices"][0]["message"]["content"]
+             return content
+
+         except RetryExhaustedError:
+             raise
+         except Exception as e:
+             if isinstance(e, (ConnectionError, TimeoutError, ParseError)):
+                 raise
+             raise ConnectionError(f"Chat completion failed: {e}")
+
+     async def completion(self, prompt: str) -> str:
+         """
+         Use completions endpoint for edge cases.
+
+         Args:
+             prompt: Text prompt
+
+         Returns:
+             Model response text
+
+         Raises:
+             ConnectionError: If request fails
+             ParseError: If response parsing fails
+         """
+         request_data = {
+             "model": self.config.model,
+             "prompt": prompt,
+             "temperature": self.config.temperature,
+             "max_tokens": self.config.max_tokens,
+             # "top_p": self.config.top_p,
+         }
+
+         try:
+             response = await self._request_with_retry(
+                 COMPLETIONS_ENDPOINT,
+                 json=request_data
+             )
+
+             # Extract text from response
+             if "choices" not in response or not response["choices"]:
+                 raise ParseError("Invalid response format: missing choices")
+
+             text = response["choices"][0]["text"]
+             return text
+
+         except RetryExhaustedError:
+             raise
+         except Exception as e:
+             if isinstance(e, (ConnectionError, TimeoutError, ParseError)):
+                 raise
+             raise ConnectionError(f"Completion failed: {e}")
+
+     async def list_models(self) -> List[str]:
+         """
+         List available models.
+
+         Returns:
+             List of model names
+
+         Raises:
+             ConnectionError: If request fails
+         """
+         try:
+             response = await self._request_with_retry(MODELS_ENDPOINT, method="GET")
+             models = response.get("data", [])
+             return [model["id"] for model in models]
+         except Exception as e:
+             if isinstance(e, ConnectionError):
+                 raise
+             raise ConnectionError(f"Failed to list models: {e}")
+
+     async def detect_model(self) -> str:
+         """
+         Auto-detect the first available model.
+
+         Returns:
+             Model name
+
+         Raises:
+             ConnectionError: If no models found
+         """
+         models = await self.list_models()
+         if not models:
+             raise ConnectionError("No models available on vLLM server")
+         return models[0]
+
+
+ def detect_model_sync(base_url: str, timeout: float = 30.0) -> str:
+     """
+     Synchronously detect the first available model.
+
+     Args:
+         base_url: vLLM server URL
+         timeout: Request timeout
+
+     Returns:
+         Model name
+
+     Raises:
+         ConnectionError: If no models found
+     """
+     url = f"{base_url}{MODELS_ENDPOINT}"
+     try:
+         with httpx.Client(timeout=timeout) as client:
+             response = client.get(url)
+             response.raise_for_status()
+             data = response.json().get("data", [])
+             models = [model["id"] for model in data]
+
+             if not models:
+                 raise ConnectionError("No models available on vLLM server")
+
+             model = models[0]
+             return model
+
+     except httpx.HTTPError as e:
+         raise ConnectionError(f"Failed to detect model: {e}")
vllm_judge/exceptions.py ADDED
@@ -0,0 +1,45 @@
+ from typing import Optional
+
+
+ class VLLMJudgeError(Exception):
+     """Base exception for all vLLM Judge errors."""
+     pass
+
+
+ class ConfigurationError(VLLMJudgeError):
+     """Raised when configuration is invalid."""
+     pass
+
+
+ class ConnectionError(VLLMJudgeError):
+     """Raised when unable to connect to vLLM server."""
+     pass
+
+
+ class TimeoutError(VLLMJudgeError):
+     """Raised when request times out."""
+     pass
+
+
+ class ParseError(VLLMJudgeError):
+     """Raised when unable to parse LLM response."""
+     def __init__(self, message: str, raw_response: Optional[str] = None):
+         super().__init__(message)
+         self.raw_response = raw_response
+
+
+ class MetricNotFoundError(VLLMJudgeError):
+     """Raised when requested metric is not found."""
+     pass
+
+
+ class InvalidInputError(VLLMJudgeError):
+     """Raised when input parameters are invalid."""
+     pass
+
+
+ class RetryExhaustedError(VLLMJudgeError):
+     """Raised when all retry attempts are exhausted."""
+     def __init__(self, message: str, last_error: Optional[Exception] = None):
+         super().__init__(message)
+         self.last_error = last_error
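
Because every class above derives from VLLMJudgeError, callers can handle failures at whatever granularity they need; note that ConnectionError and TimeoutError shadow the Python builtins of the same name, so aliasing them on import avoids confusion. An illustrative pattern, not from the package itself:

from vllm_judge.exceptions import (
    ConnectionError as JudgeConnectionError,
    ParseError,
    VLLMJudgeError,
)


async def safe_chat(client, messages):
    try:
        return await client.chat_completion(messages)
    except ParseError as e:
        # ParseError carries the raw model output for debugging
        print(f"Unparseable response: {e}; raw={e.raw_response!r}")
    except JudgeConnectionError as e:
        print(f"Server unreachable: {e}")
    except VLLMJudgeError as e:
        # Catch-all for any other library-defined failure
        print(f"Judge error: {e}")
    return None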