vllm-judge 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vllm_judge/api/server.py ADDED
@@ -0,0 +1,564 @@
+ """
+ FastAPI server for vLLM Judge API.
+ """
+ import asyncio
+ import time
+ import uuid
+ from datetime import datetime
+ from typing import Dict, List, Optional, Any
+ from contextlib import asynccontextmanager
+
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
+ from fastapi.responses import JSONResponse
+ import uvicorn
+
+ from vllm_judge.judge import Judge
+ from vllm_judge.models import EvaluationResult, JudgeConfig
+ from vllm_judge.metrics import BUILTIN_METRICS
+ from vllm_judge.exceptions import VLLMJudgeError
+ from vllm_judge.api.models import (
+     EvaluateRequest,
+     BatchEvaluateRequest,
+     AsyncBatchRequest,
+     EvaluationResponse,
+     BatchResponse,
+     AsyncBatchResponse,
+     JobStatusResponse,
+     MetricInfo,
+     HealthResponse,
+     ErrorResponse
+ )
+ from vllm_judge.templating import TemplateProcessor
+ from vllm_judge.models import TemplateEngine
+ from vllm_judge import __version__
+
+
+ # Global state
+ judge: Optional[Judge] = None
+ app_start_time: float = 0
+ total_evaluations: int = 0
+ active_connections: int = 0
+ jobs: Dict[str, Dict[str, Any]] = {}  # job_id -> job info
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Manage application lifecycle."""
+     global app_start_time
+     app_start_time = time.time()
+     yield
+     # Cleanup
+     if judge:
+         await judge.close()
+
+
+ app = FastAPI(
+     title="vLLM Judge API",
+     description="LLM-as-a-Judge evaluation service for vLLM hosted models",
+     version=__version__,
+     lifespan=lifespan
+ )
+
+
+ @app.exception_handler(VLLMJudgeError)
+ async def vllm_judge_exception_handler(request, exc: VLLMJudgeError):
+     """Handle vLLM Judge specific exceptions."""
+     return JSONResponse(
+         status_code=400,
+         content=ErrorResponse(
+             error=exc.__class__.__name__,
+             detail=str(exc),
+             code="VLLM_JUDGE_ERROR"
+         ).model_dump()
+     )
+
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """Health check endpoint."""
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     uptime = time.time() - app_start_time
+
+     return HealthResponse(
+         status="healthy",
+         version=__version__,
+         model=judge.config.model,
+         base_url=judge.config.base_url,
+         uptime_seconds=uptime,
+         total_evaluations=total_evaluations,
+         active_connections=active_connections,
+         metrics_available=len(judge.list_metrics())
+     )
+
+
+ @app.post("/evaluate", response_model=EvaluationResponse)
+ async def evaluate(request: EvaluateRequest):
+     """Single evaluation endpoint."""
+     global total_evaluations
+
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     start_time = time.time()
+
+     try:
+         # Convert scale list to tuple if provided
+         scale = tuple(request.scale) if request.scale else None
+
+         # Perform evaluation with template support
+         result = await judge.evaluate(
+             response=request.response,
+             criteria=request.criteria,
+             rubric=request.rubric,
+             scale=scale,
+             metric=request.metric,
+             context=request.context,
+             system_prompt=request.system_prompt,
+             examples=request.examples,
+             template_vars=request.template_vars,
+             template_engine=request.template_engine
+         )
+
+         # Convert to response model
+         duration_ms = int((time.time() - start_time) * 1000)
+         total_evaluations += 1
+
+         return EvaluationResponse(
+             decision=result.decision,
+             reasoning=result.reasoning,
+             score=result.score,
+             metadata=result.metadata,
+             evaluation_id=str(uuid.uuid4()),
+             timestamp=datetime.utcnow(),
+             duration_ms=duration_ms
+         )
+
+     except VLLMJudgeError as e:
+         raise HTTPException(status_code=400, detail=str(e))
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
+
+
+ @app.post("/batch", response_model=BatchResponse)
+ async def batch_evaluate(request: BatchEvaluateRequest):
+     """Synchronous batch evaluation endpoint."""
+     global total_evaluations
+
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     # Apply defaults if provided
+     if request.default_criteria or request.default_metric:
+         for item in request.data:
+             if request.default_criteria and "criteria" not in item:
+                 item["criteria"] = request.default_criteria
+             if request.default_metric and "metric" not in item:
+                 item["metric"] = request.default_metric
+
+     try:
+         # Perform batch evaluation
+         batch_result = await judge.batch_evaluate(
+             data=request.data,
+             max_concurrent=request.max_concurrent
+         )
+
+         # Convert results
+         results = []
+         for i, r in enumerate(batch_result.results):
+             if isinstance(r, EvaluationResult):
+                 results.append(EvaluationResponse(
+                     decision=r.decision,
+                     reasoning=r.reasoning,
+                     score=r.score,
+                     metadata=r.metadata,
+                     evaluation_id=str(uuid.uuid4()),
+                     timestamp=datetime.utcnow()
+                 ))
+             else:
+                 # Error case
+                 results.append({
+                     "error": str(r),
+                     "index": i
+                 })
+
+         total_evaluations += batch_result.successful
+
+         return BatchResponse(
+             total=batch_result.total,
+             successful=batch_result.successful,
+             failed=batch_result.failed,
+             success_rate=batch_result.success_rate,
+             duration_seconds=batch_result.duration_seconds,
+             results=results
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Batch evaluation failed: {str(e)}")
+
+
+ @app.post("/batch/async", response_model=AsyncBatchResponse)
+ async def async_batch_evaluate(
+     request: AsyncBatchRequest,
+     background_tasks: BackgroundTasks
+ ):
+     """Asynchronous batch evaluation endpoint."""
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     # Create job
+     job_id = str(uuid.uuid4())
+     job_info = {
+         "id": job_id,
+         "status": "pending",
+         "data": request.data,
+         "total": len(request.data),
+         "completed": 0,
+         "created_at": datetime.utcnow(),
+         "callback_url": request.callback_url,
+         "max_concurrent": request.max_concurrent
+     }
+     jobs[job_id] = job_info
+
+     # Estimate duration (rough estimate: 0.5s per evaluation)
+     estimated_duration = len(request.data) * 0.5 / (request.max_concurrent or judge.config.max_concurrent)
+
+     # Start background task
+     background_tasks.add_task(
+         run_async_batch,
+         job_id,
+         request.data,
+         request.max_concurrent,
+         request.callback_url
+     )
+
+     return AsyncBatchResponse(
+         job_id=job_id,
+         status="pending",
+         total_items=len(request.data),
+         created_at=job_info["created_at"],
+         estimated_duration_seconds=estimated_duration
+     )
+
+
+ async def run_async_batch(
+     job_id: str,
+     data: List[Dict[str, Any]],
+     max_concurrent: Optional[int],
+     callback_url: Optional[str]
+ ):
+     """Run batch evaluation in background."""
+     global total_evaluations
+
+     job = jobs[job_id]
+     job["status"] = "running"
+     job["started_at"] = datetime.utcnow()
+
+     try:
+         # Progress callback
+         def update_progress(completed: int, total: int):
+             job["completed"] = completed
+
+         # Run evaluation
+         batch_result = await judge.batch_evaluate(
+             data=data,
+             max_concurrent=max_concurrent,
+             progress_callback=update_progress
+         )
+
+         # Update job
+         job["status"] = "completed"
+         job["completed_at"] = datetime.utcnow()
+         job["result"] = batch_result
+         total_evaluations += batch_result.successful
+
+         # Send callback if provided
+         if callback_url:
+             # TODO: Implement callback POST request
+             pass
+
+     except Exception as e:
+         job["status"] = "failed"
+         job["error"] = str(e)
+         job["completed_at"] = datetime.utcnow()
+
+
+ @app.get("/jobs/{job_id}", response_model=JobStatusResponse)
+ async def get_job_status(job_id: str):
+     """Get status of async job."""
+     if job_id not in jobs:
+         raise HTTPException(status_code=404, detail="Job not found")
+
+     job = jobs[job_id]
+
+     return JobStatusResponse(
+         job_id=job_id,
+         status=job["status"],
+         progress={"completed": job.get("completed", 0), "total": job["total"]},
+         created_at=job["created_at"],
+         started_at=job.get("started_at"),
+         completed_at=job.get("completed_at"),
+         result_url=f"/jobs/{job_id}/result" if job["status"] == "completed" else None,
+         error=job.get("error")
+     )
+
+
+ @app.get("/jobs/{job_id}/result")
+ async def get_job_result(job_id: str):
+     """Get result of completed async job."""
+     if job_id not in jobs:
+         raise HTTPException(status_code=404, detail="Job not found")
+
+     job = jobs[job_id]
+
+     if job["status"] != "completed":
+         raise HTTPException(
+             status_code=400,
+             detail=f"Job is {job['status']}, not completed"
+         )
+
+     if "result" not in job:
+         raise HTTPException(status_code=500, detail="Job result not found")
+
+     batch_result = job["result"]
+
+     # Convert to response format
+     results = []
+     for r in batch_result.results:
+         if isinstance(r, EvaluationResult):
+             results.append({
+                 "decision": r.decision,
+                 "reasoning": r.reasoning,
+                 "score": r.score,
+                 "metadata": r.metadata
+             })
+         else:
+             results.append({"error": str(r)})
+
+     return {
+         "job_id": job_id,
+         "total": batch_result.total,
+         "successful": batch_result.successful,
+         "failed": batch_result.failed,
+         "success_rate": batch_result.success_rate,
+         "duration_seconds": batch_result.duration_seconds,
+         "results": results
+     }
+
+
+ @app.get("/metrics", response_model=List[MetricInfo])
+ async def list_metrics():
+     """List all available metrics."""
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     metrics_info = []
+
+     # Get all metrics (user-registered + built-in)
+     all_metrics = {**judge.metrics, **BUILTIN_METRICS}
+
+     for name, metric in all_metrics.items():
+         info = MetricInfo(
+             name=name,
+             criteria=metric.criteria,
+             has_scale=metric.scale is not None,
+             scale=metric.scale,
+             has_rubric=metric.rubric is not None,
+             rubric_type=type(metric.rubric).__name__ if metric.rubric else None,
+             has_examples=bool(metric.examples),
+             example_count=len(metric.examples) if metric.examples else 0,
+             has_system_prompt=metric.system_prompt is not None,
+             has_template_vars=bool(metric.template_vars),
+             template_vars=metric.template_vars if metric.template_vars else None,
+             required_vars=metric.required_vars if hasattr(metric, 'required_vars') else None,
+             template_engine=metric.template_engine.value if hasattr(metric, 'template_engine') else None
+         )
+         metrics_info.append(info)
+
+     return metrics_info
+
+
+ @app.get("/metrics/{metric_name}")
+ async def get_metric_details(metric_name: str):
+     """Get detailed information about a specific metric."""
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     try:
+         metric = judge.get_metric(metric_name)
+     except Exception:
+         raise HTTPException(status_code=404, detail=f"Metric '{metric_name}' not found")
+
+     return {
+         "name": metric_name,
+         "criteria": metric.criteria,
+         "scale": metric.scale,
+         "rubric": metric.rubric,
+         "examples": metric.examples,
+         "system_prompt": metric.system_prompt,
+         "template_vars": getattr(metric, 'template_vars', None),
+         "required_vars": getattr(metric, 'required_vars', None),
+         "template_engine": getattr(metric, 'template_engine', None)
+     }
+
+
+ @app.websocket("/ws/evaluate")
+ async def websocket_evaluate(websocket: WebSocket):
+     """WebSocket endpoint for real-time evaluations."""
+     global active_connections
+
+     await websocket.accept()
+     active_connections += 1
+
+     try:
+         while True:
+             # Receive evaluation request
+             data = await websocket.receive_json()
+
+             try:
+                 # Perform evaluation
+                 request = EvaluateRequest(**data)
+                 scale = tuple(request.scale) if request.scale else None
+
+                 result = await judge.evaluate(
+                     response=request.response,
+                     criteria=request.criteria,
+                     rubric=request.rubric,
+                     scale=scale,
+                     metric=request.metric,
+                     context=request.context,
+                     system_prompt=request.system_prompt,
+                     examples=request.examples,
+                     template_vars=request.template_vars,
+                     template_engine=request.template_engine
+                 )
+
+                 # Send result
+                 await websocket.send_json({
+                     "status": "success",
+                     "result": {
+                         "decision": result.decision,
+                         "reasoning": result.reasoning,
+                         "score": result.score,
+                         "metadata": result.metadata
+                     }
+                 })
+
+             except Exception as e:
+                 await websocket.send_json({
+                     "status": "error",
+                     "error": str(e)
+                 })
+
+     except WebSocketDisconnect:
+         active_connections -= 1
+
+
+ @app.post("/validate/template")
+ async def validate_template(request: Dict[str, Any]):
+     """Validate template variables for a given template."""
+     template = request.get("template", "")
+     template_vars = request.get("template_vars", {})
+     engine = request.get("template_engine", "format")
+
+     try:
+         # Get required variables
+         required_vars = TemplateProcessor.get_required_vars(
+             template,
+             TemplateEngine(engine)
+         )
+
+         # Check which are missing
+         provided_vars = set(template_vars.keys())
+         missing_vars = required_vars - provided_vars
+
+         # Try to apply template
+         try:
+             result = TemplateProcessor.apply_template(
+                 template,
+                 template_vars,
+                 TemplateEngine(engine),
+                 strict=True
+             )
+
+             return {
+                 "valid": True,
+                 "required_vars": list(required_vars),
+                 "provided_vars": list(provided_vars),
+                 "missing_vars": list(missing_vars),
+                 "result": result
+             }
+         except Exception as e:
+             return {
+                 "valid": False,
+                 "required_vars": list(required_vars),
+                 "provided_vars": list(provided_vars),
+                 "missing_vars": list(missing_vars),
+                 "error": str(e)
+             }
+
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Validation error: {str(e)}")
+
+
+ @app.post("/metrics/register")
+ async def register_metric(metric_data: Dict[str, Any]):
+     """Register a new metric dynamically."""
+     if not judge:
+         raise HTTPException(status_code=503, detail="Judge not initialized")
+
+     try:
+         # Create metric from data
+         from vllm_judge.models import Metric
+
+         metric = Metric(
+             name=metric_data["name"],
+             criteria=metric_data["criteria"],
+             rubric=metric_data.get("rubric"),
+             scale=tuple(metric_data["scale"]) if metric_data.get("scale") else None,
+             examples=metric_data.get("examples", []),
+             system_prompt=metric_data.get("system_prompt"),
+             template_vars=metric_data.get("template_vars", {}),
+             required_vars=metric_data.get("required_vars", []),
+             template_engine=metric_data.get("template_engine", "format")
+         )
+
+         # Register with judge
+         judge.register_metric(metric)
+
+         return {"message": f"Metric '{metric.name}' registered successfully"}
+
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Failed to register metric: {str(e)}")
+
+
+ def create_app(config: JudgeConfig) -> FastAPI:
+     """Create FastAPI app with initialized Judge."""
+     global judge
+     judge = Judge(config)
+     return app
+
+
+ def start_server(
+     base_url: str,
+     model: Optional[str] = None,
+     host: str = "0.0.0.0",
+     port: int = 8080,
+     reload: bool = False,
+     **kwargs
+ ):
+     """Start the API server."""
+     global judge
+
+     # Initialize judge
+     config = JudgeConfig.from_url(base_url, model=model, **kwargs)
+     judge = Judge(config)
+
+     # Run server
+     uvicorn.run(
+         "vllm_judge.api.server:app",
+         host=host,
+         port=port,
+         reload=reload
+     )
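
A quick way to sanity-check this server is to call the /evaluate endpoint over HTTP. The sketch below is illustrative only, not part of the package: it assumes a server started via start_server() on its default port 8080, the payload fields mirror the EvaluateRequest fields read by the handler above, and the printed keys come from EvaluationResponse.

    import requests  # assumed to be installed; any HTTP client works

    # Hypothetical payload: "response" is the text being judged, "criteria"
    # is the evaluation criteria forwarded to judge.evaluate() above.
    payload = {
        "response": "The capital of France is Paris.",
        "criteria": "factual accuracy",
    }

    resp = requests.post("http://localhost:8080/evaluate", json=payload, timeout=60)
    resp.raise_for_status()
    result = resp.json()
    print(result["decision"], result["score"], result["reasoning"])
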
vllm_judge/batch.py ADDED
@@ -0,0 +1,147 @@
+ import asyncio
+ import time
+ from typing import List, Dict, Any, Callable, Optional, Union
+ from vllm_judge.models import EvaluationResult, BatchResult
+ from vllm_judge.exceptions import VLLMJudgeError
+
+
+ class BatchProcessor:
+     """High-concurrency batch processing for evaluations."""
+
+     def __init__(self, judge, max_concurrent: int = 50):
+         """
+         Initialize batch processor.
+
+         Args:
+             judge: Judge instance
+             max_concurrent: Maximum concurrent requests
+         """
+         self.judge = judge
+         self.semaphore = asyncio.Semaphore(max_concurrent)
+         self.progress_lock = asyncio.Lock()
+         self.completed = 0
+
+     async def process(
+         self,
+         data: List[Dict[str, Any]],
+         progress_callback: Optional[Callable[[int, int], None]] = None,
+         **default_kwargs
+     ) -> BatchResult:
+         """
+         Process batch of evaluations.
+
+         Args:
+             data: List of evaluation inputs
+             progress_callback: Optional callback for progress updates
+             **default_kwargs: Default parameters for all evaluations
+
+         Returns:
+             BatchResult with all results
+         """
+         start_time = time.time()
+         self.completed = 0
+         total = len(data)
+
+         # Create tasks
+         tasks = []
+         for i, item in enumerate(data):
+             # Merge default kwargs with item-specific kwargs
+             eval_kwargs = {**default_kwargs, **item}
+
+             task = self._process_item(
+                 eval_kwargs,
+                 i,
+                 total,
+                 progress_callback
+             )
+             tasks.append(task)
+
+         # Process all tasks
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         # Calculate statistics
+         successful = sum(1 for r in results if isinstance(r, EvaluationResult))
+         failed = total - successful
+         duration = time.time() - start_time
+
+         return BatchResult(
+             results=results,
+             total=total,
+             successful=successful,
+             failed=failed,
+             duration_seconds=duration
+         )
+
+     async def _process_item(
+         self,
+         eval_kwargs: Dict[str, Any],
+         index: int,
+         total: int,
+         progress_callback: Optional[Callable]
+     ) -> Union[EvaluationResult, Exception]:
+         """Process single item with concurrency control."""
+         async with self.semaphore:
+             try:
+                 # Extract response from kwargs
+                 response = eval_kwargs.pop('response', None)
+                 if not response:
+                     raise ValueError(f"Item {index} missing 'response' field")
+
+                 # Perform evaluation
+                 result = await self.judge.evaluate(response=response, **eval_kwargs)
+
+                 # Update progress
+                 async with self.progress_lock:
+                     self.completed += 1
+                     if progress_callback:
+                         progress_callback(self.completed, total)
+
+                 # Add index to metadata
+                 result.metadata['batch_index'] = index
+                 return result
+
+             except Exception as e:
+                 # Update progress even for failures
+                 async with self.progress_lock:
+                     self.completed += 1
+                     if progress_callback:
+                         progress_callback(self.completed, total)
+
+                 # Return exception with context
+                 error = VLLMJudgeError(f"Item {index} failed: {str(e)}")
+                 error.batch_index = index
+                 error.original_error = e
+                 return error
+
+     async def process_streaming(
+         self,
+         data: List[Dict[str, Any]],
+         callback: Callable[[int, Union[EvaluationResult, Exception]], None],
+         **default_kwargs
+     ):
+         """
+         Process batch with streaming results.
+
+         Args:
+             data: List of evaluation inputs
+             callback: Called with (index, result) as results complete
+             **default_kwargs: Default parameters for all evaluations
+         """
+         async def process_and_callback(item, index):
+             result = await self._process_item(
+                 {**default_kwargs, **item},
+                 index,
+                 len(data),
+                 None
+             )
+             callback(index, result)
+             return result
+
+         tasks = [
+             process_and_callback(item, i)
+             for i, item in enumerate(data)
+         ]
+
+         # Process tasks as they complete
+         for coro in asyncio.as_completed(tasks):
+             await coro
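
For reference, a minimal driver for BatchProcessor might look like the sketch below. It is an assumption-laden example, not part of the package: the Judge construction mirrors the JudgeConfig.from_url() call used in start_server(), the URL and model name are placeholders, and each item carries the 'response' field that _process_item() requires.

    import asyncio

    from vllm_judge.judge import Judge
    from vllm_judge.models import JudgeConfig
    from vllm_judge.batch import BatchProcessor

    async def main():
        # Placeholder endpoint/model; JudgeConfig.from_url is the same
        # constructor used by start_server() in server.py.
        judge = Judge(JudgeConfig.from_url("http://localhost:8000", model="example-model"))

        processor = BatchProcessor(judge, max_concurrent=10)

        # Each item needs a 'response' key; remaining keys are forwarded
        # to judge.evaluate() by _process_item().
        data = [
            {"response": "Answer A", "criteria": "helpfulness"},
            {"response": "Answer B", "criteria": "helpfulness"},
        ]

        result = await processor.process(
            data,
            progress_callback=lambda done, total: print(f"{done}/{total} done"),
        )
        print(f"{result.successful}/{result.total} succeeded in {result.duration_seconds:.1f}s")

    asyncio.run(main())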