vllm-judge 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vllm_judge/__init__.py +120 -0
- vllm_judge/api/__init__.py +39 -0
- vllm_judge/api/client.py +354 -0
- vllm_judge/api/models.py +157 -0
- vllm_judge/api/server.py +564 -0
- vllm_judge/batch.py +147 -0
- vllm_judge/cli.py +288 -0
- vllm_judge/client.py +262 -0
- vllm_judge/exceptions.py +42 -0
- vllm_judge/judge.py +421 -0
- vllm_judge/metrics.py +417 -0
- vllm_judge/models.py +185 -0
- vllm_judge/prompts.py +175 -0
- vllm_judge/templating.py +206 -0
- vllm_judge-0.1.0.dist-info/METADATA +124 -0
- vllm_judge-0.1.0.dist-info/RECORD +19 -0
- vllm_judge-0.1.0.dist-info/WHEEL +5 -0
- vllm_judge-0.1.0.dist-info/entry_points.txt +2 -0
- vllm_judge-0.1.0.dist-info/top_level.txt +1 -0
vllm_judge/api/server.py
ADDED
@@ -0,0 +1,564 @@
|
|
1
|
+
"""
|
2
|
+
FastAPI server for vLLM Judge API.
|
3
|
+
"""
|
4
|
+
import asyncio
|
5
|
+
import time
|
6
|
+
import uuid
|
7
|
+
from datetime import datetime
|
8
|
+
from typing import Dict, List, Optional, Any
|
9
|
+
from contextlib import asynccontextmanager
|
10
|
+
|
11
|
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect
|
12
|
+
from fastapi.responses import JSONResponse
|
13
|
+
import uvicorn
|
14
|
+
|
15
|
+
from vllm_judge.judge import Judge
|
16
|
+
from vllm_judge.models import EvaluationResult, JudgeConfig
|
17
|
+
from vllm_judge.metrics import BUILTIN_METRICS
|
18
|
+
from vllm_judge.exceptions import VLLMJudgeError
|
19
|
+
from vllm_judge.api.models import (
|
20
|
+
EvaluateRequest,
|
21
|
+
BatchEvaluateRequest,
|
22
|
+
AsyncBatchRequest,
|
23
|
+
EvaluationResponse,
|
24
|
+
BatchResponse,
|
25
|
+
AsyncBatchResponse,
|
26
|
+
JobStatusResponse,
|
27
|
+
MetricInfo,
|
28
|
+
HealthResponse,
|
29
|
+
ErrorResponse
|
30
|
+
)
|
31
|
+
from vllm_judge.templating import TemplateProcessor
|
32
|
+
from vllm_judge.models import TemplateEngine
|
33
|
+
from vllm_judge import __version__
|
34
|
+
|
35
|
+
|
36
|
+
# ---------------------------------------------------------------------------
# Module-level server state shared by the request handlers in this file.
# ---------------------------------------------------------------------------
# The judge used by every endpoint; assigned by create_app()/start_server().
judge: Optional[Judge] = None
# Async batch jobs; nothing in this file ever removes an entry.
jobs: Dict[str, Dict[str, Any]] = {}  # job_id -> job info
# Counters reported by the /health endpoint.
total_evaluations: int = 0
active_connections: int = 0
# Overwritten with time.time() by the lifespan handler on startup.
app_start_time: float = 0
|
42
|
+
|
43
|
+
|
44
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application lifecycle.

    On startup, records ``app_start_time`` (used by ``/health`` to report
    uptime). On shutdown, awaits ``judge.close()`` if a judge was configured.

    The cleanup sits in a ``finally`` block so the judge is still closed when
    the application exits with an exception instead of a clean shutdown.
    """
    global app_start_time
    app_start_time = time.time()
    try:
        yield
    finally:
        # Release the judge's resources regardless of how serving ended.
        if judge:
            await judge.close()
|
53
|
+
|
54
|
+
|
55
|
+
# The single FastAPI application served by this module; `lifespan` handles
# startup/shutdown bookkeeping.
app = FastAPI(
    lifespan=lifespan,
    title="vLLM Judge API",
    version=__version__,
    description="LLM-as-a-Judge evaluation service for vLLM hosted models",
)
|
61
|
+
|
62
|
+
|
63
|
+
@app.exception_handler(VLLMJudgeError)
async def vllm_judge_exception_handler(request, exc: VLLMJudgeError):
    """Turn an unhandled ``VLLMJudgeError`` into a 400 JSON error response.

    Args:
        request: The incoming request (not used).
        exc: The library error that escaped a route handler.

    Returns:
        A ``JSONResponse`` whose body is a serialized ``ErrorResponse``.
    """
    payload = ErrorResponse(
        error=type(exc).__name__,
        detail=str(exc),
        code="VLLM_JUDGE_ERROR",
    )
    return JSONResponse(content=payload.model_dump(), status_code=400)
|
74
|
+
|
75
|
+
|
76
|
+
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report liveness together with basic server statistics.

    Raises:
        HTTPException: 503 when no judge has been configured yet.
    """
    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    elapsed = time.time() - app_start_time
    cfg = judge.config

    return HealthResponse(
        status="healthy",
        version=__version__,
        model=cfg.model,
        base_url=cfg.base_url,
        uptime_seconds=elapsed,
        total_evaluations=total_evaluations,
        active_connections=active_connections,
        metrics_available=len(judge.list_metrics()),
    )
|
94
|
+
|
95
|
+
|
96
|
+
@app.post("/evaluate", response_model=EvaluationResponse)
async def evaluate(request: EvaluateRequest):
    """Run a single evaluation with the configured judge.

    Returns:
        An ``EvaluationResponse`` carrying the judge's decision, reasoning,
        score and metadata, plus a fresh ``evaluation_id``, a timestamp from
        ``datetime.utcnow()`` and the elapsed time in milliseconds.

    Raises:
        HTTPException: 503 if the judge is not initialized, 400 if the
            evaluation raises ``VLLMJudgeError``, 500 for any other error.
    """
    global total_evaluations

    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    # perf_counter is monotonic: the measured duration is unaffected by
    # system clock adjustments, unlike time.time().
    start = time.perf_counter()

    try:
        # The request carries the scale as a list; the judge is given a tuple.
        scale = tuple(request.scale) if request.scale else None

        # Perform evaluation with template support
        result = await judge.evaluate(
            response=request.response,
            criteria=request.criteria,
            rubric=request.rubric,
            scale=scale,
            metric=request.metric,
            context=request.context,
            system_prompt=request.system_prompt,
            examples=request.examples,
            template_vars=request.template_vars,
            template_engine=request.template_engine
        )

        duration_ms = int((time.perf_counter() - start) * 1000)
        total_evaluations += 1

        return EvaluationResponse(
            decision=result.decision,
            reasoning=result.reasoning,
            score=result.score,
            metadata=result.metadata,
            evaluation_id=str(uuid.uuid4()),
            timestamp=datetime.utcnow(),
            duration_ms=duration_ms
        )

    except VLLMJudgeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
|
142
|
+
|
143
|
+
|
144
|
+
@app.post("/batch", response_model=BatchResponse)
async def batch_evaluate(request: BatchEvaluateRequest):
    """Evaluate a batch of items synchronously and return all results.

    ``default_criteria`` / ``default_metric`` from the request are written into
    every item that does not already have a ``criteria`` / ``metric`` key.
    Failed items appear in ``results`` as ``{"error": ..., "index": i}``, where
    ``i`` is the item's position in the batch.

    Raises:
        HTTPException: 503 if the judge is not initialized, 400 if the batch
            call raises ``VLLMJudgeError`` (same mapping as ``/evaluate``),
            500 for any other error.
    """
    global total_evaluations

    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    # Fill in request-level defaults without overriding per-item values.
    for item in request.data:
        if request.default_criteria and "criteria" not in item:
            item["criteria"] = request.default_criteria
        if request.default_metric and "metric" not in item:
            item["metric"] = request.default_metric

    try:
        batch_result = await judge.batch_evaluate(
            data=request.data,
            max_concurrent=request.max_concurrent
        )

        results = []
        for i, r in enumerate(batch_result.results):
            if isinstance(r, EvaluationResult):
                results.append(EvaluationResponse(
                    decision=r.decision,
                    reasoning=r.reasoning,
                    score=r.score,
                    metadata=r.metadata,
                    evaluation_id=str(uuid.uuid4()),
                    timestamp=datetime.utcnow()
                ))
            else:
                # Anything that is not an EvaluationResult is a per-item failure.
                results.append({
                    "error": str(r),
                    "index": i
                })

        total_evaluations += batch_result.successful

        return BatchResponse(
            total=batch_result.total,
            successful=batch_result.successful,
            failed=batch_result.failed,
            success_rate=batch_result.success_rate,
            duration_seconds=batch_result.duration_seconds,
            results=results
        )

    except VLLMJudgeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch evaluation failed: {e}") from e
|
199
|
+
|
200
|
+
|
201
|
+
@app.post("/batch/async", response_model=AsyncBatchResponse)
async def async_batch_evaluate(
    request: AsyncBatchRequest,
    background_tasks: BackgroundTasks
):
    """Queue a batch evaluation as a background task and return immediately.

    A ``pending`` job record is stored in ``jobs`` and ``run_async_batch`` is
    scheduled through FastAPI's ``BackgroundTasks``. Clients poll
    ``/jobs/{job_id}`` for progress.

    Raises:
        HTTPException: 503 if the judge is not initialized.
    """
    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    items = request.data
    item_count = len(items)
    job_id = str(uuid.uuid4())
    created = datetime.utcnow()

    jobs[job_id] = {
        "id": job_id,
        "status": "pending",
        "data": items,
        "total": item_count,
        "completed": 0,
        "created_at": created,
        "callback_url": request.callback_url,
        "max_concurrent": request.max_concurrent,
    }

    # Rough ETA: about 0.5 s per evaluation, divided by the concurrency limit
    # (the request's value, else the judge config's).
    concurrency = request.max_concurrent or judge.config.max_concurrent
    eta_seconds = item_count * 0.5 / concurrency

    background_tasks.add_task(
        run_async_batch,
        job_id,
        items,
        request.max_concurrent,
        request.callback_url,
    )

    return AsyncBatchResponse(
        job_id=job_id,
        status="pending",
        total_items=item_count,
        created_at=created,
        estimated_duration_seconds=eta_seconds,
    )
|
243
|
+
|
244
|
+
|
245
|
+
async def run_async_batch(
    job_id: str,
    data: List[Dict[str, Any]],
    max_concurrent: Optional[int],
    callback_url: Optional[str]
):
    """Background worker behind ``/batch/async``.

    Moves ``jobs[job_id]`` to ``running``, keeps its ``"completed"`` count in
    sync through the progress callback, and finishes as either ``completed``
    (``BatchResult`` stored under ``"result"``) or ``failed`` (message stored
    under ``"error"``). ``callback_url`` is accepted but not used yet.
    """
    global total_evaluations

    job = jobs[job_id]
    job.update(status="running", started_at=datetime.utcnow())

    def _track_progress(completed: int, total: int):
        # Expose the processor's running count to /jobs/{job_id} pollers.
        job["completed"] = completed

    try:
        outcome = await judge.batch_evaluate(
            data=data,
            max_concurrent=max_concurrent,
            progress_callback=_track_progress,
        )

        job.update(
            status="completed",
            completed_at=datetime.utcnow(),
            result=outcome,
        )
        total_evaluations += outcome.successful

        if callback_url:
            # TODO: Implement callback POST request (the URL is ignored for now).
            pass

    except Exception as exc:
        job.update(
            status="failed",
            error=str(exc),
            completed_at=datetime.utcnow(),
        )
|
285
|
+
|
286
|
+
|
287
|
+
@app.get("/jobs/{job_id}", response_model=JobStatusResponse)
async def get_job_status(job_id: str):
    """Describe the current state of an async batch job.

    ``result_url`` is only populated once the job status is ``completed``.

    Raises:
        HTTPException: 404 if ``job_id`` is unknown.
    """
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    status = job["status"]
    finished = status == "completed"
    progress = {"completed": job.get("completed", 0), "total": job["total"]}

    return JobStatusResponse(
        job_id=job_id,
        status=status,
        progress=progress,
        created_at=job["created_at"],
        started_at=job.get("started_at"),
        completed_at=job.get("completed_at"),
        result_url=f"/jobs/{job_id}/result" if finished else None,
        error=job.get("error"),
    )
|
305
|
+
|
306
|
+
|
307
|
+
@app.get("/jobs/{job_id}/result")
async def get_job_result(job_id: str):
    """Return the results of a completed async batch job.

    Successful items are reported as ``{"decision", "reasoning", "score",
    "metadata"}``. Failed items are reported as ``{"error", "index"}``, where
    ``index`` is the item's position in the batch (the same shape ``/batch``
    uses), so clients can tell which input failed.

    Raises:
        HTTPException: 404 if the job is unknown, 400 if it has not completed,
            500 if a completed job has no stored result.
    """
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    job = jobs[job_id]

    if job["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Job is {job['status']}, not completed"
        )

    if "result" not in job:
        raise HTTPException(status_code=500, detail="Job result not found")

    batch_result = job["result"]

    results = []
    for i, r in enumerate(batch_result.results):
        if isinstance(r, EvaluationResult):
            results.append({
                "decision": r.decision,
                "reasoning": r.reasoning,
                "score": r.score,
                "metadata": r.metadata
            })
        else:
            results.append({"error": str(r), "index": i})

    return {
        "job_id": job_id,
        "total": batch_result.total,
        "successful": batch_result.successful,
        "failed": batch_result.failed,
        "success_rate": batch_result.success_rate,
        "duration_seconds": batch_result.duration_seconds,
        "results": results
    }
|
348
|
+
|
349
|
+
|
350
|
+
@app.get("/metrics", response_model=List[MetricInfo])
async def list_metrics():
    """List every metric known to the server (user-registered and built-in).

    When a user-registered metric reuses a built-in name, the user's
    definition is the one described. The listing order is user-registered
    metrics first, then the remaining built-ins.

    Raises:
        HTTPException: 503 if the judge is not initialized.
    """
    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    # Start from the user's metrics and only add built-ins whose names are not
    # taken, so a built-in never shadows a registered override.
    all_metrics = dict(judge.metrics)
    for builtin_name, builtin_metric in BUILTIN_METRICS.items():
        all_metrics.setdefault(builtin_name, builtin_metric)

    metrics_info = []
    for name, metric in all_metrics.items():
        # template_engine may be an enum, a plain string or absent/None;
        # report the enum's value when there is one, otherwise the raw value.
        engine = getattr(metric, "template_engine", None)
        metrics_info.append(MetricInfo(
            name=name,
            criteria=metric.criteria,
            has_scale=metric.scale is not None,
            scale=metric.scale,
            has_rubric=metric.rubric is not None,
            rubric_type=type(metric.rubric).__name__ if metric.rubric else None,
            has_examples=bool(metric.examples),
            example_count=len(metric.examples) if metric.examples else 0,
            has_system_prompt=metric.system_prompt is not None,
            has_template_vars=bool(metric.template_vars),
            template_vars=metric.template_vars if metric.template_vars else None,
            required_vars=getattr(metric, "required_vars", None),
            template_engine=getattr(engine, "value", engine)
        ))

    return metrics_info
|
380
|
+
|
381
|
+
|
382
|
+
@app.get("/metrics/{metric_name}")
async def get_metric_details(metric_name: str):
    """Return the full definition of a single metric.

    Raises:
        HTTPException: 503 if the judge is not initialized; 404 if
            ``judge.get_metric`` raises anything for this name.
    """
    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    try:
        metric = judge.get_metric(metric_name)
    except Exception:
        raise HTTPException(status_code=404, detail=f"Metric '{metric_name}' not found")

    details = {"name": metric_name}
    # Core attributes are read directly.
    for attr in ("criteria", "scale", "rubric", "examples", "system_prompt"):
        details[attr] = getattr(metric, attr)
    # Template-related attributes fall back to None when absent.
    for attr in ("template_vars", "required_vars", "template_engine"):
        details[attr] = getattr(metric, attr, None)
    return details
|
404
|
+
|
405
|
+
|
406
|
+
@app.websocket("/ws/evaluate")
async def websocket_evaluate(websocket: WebSocket):
    """Serve evaluations over a WebSocket connection.

    Each JSON message is parsed as an ``EvaluateRequest`` and answered with
    ``{"status": "success", "result": {...}}``. If parsing or evaluation
    fails, the reply is ``{"status": "error", "error": "<message>"}``. The loop
    runs until the client disconnects.

    Successful evaluations are added to ``total_evaluations``, as
    ``/evaluate`` does. ``active_connections`` is decremented on every exit
    path.
    """
    global active_connections, total_evaluations

    await websocket.accept()
    active_connections += 1

    try:
        while True:
            data = await websocket.receive_json()

            try:
                request = EvaluateRequest(**data)
                scale = tuple(request.scale) if request.scale else None

                result = await judge.evaluate(
                    response=request.response,
                    criteria=request.criteria,
                    rubric=request.rubric,
                    scale=scale,
                    metric=request.metric,
                    context=request.context,
                    system_prompt=request.system_prompt,
                    examples=request.examples,
                    template_vars=request.template_vars,
                    template_engine=request.template_engine
                )
                total_evaluations += 1

                await websocket.send_json({
                    "status": "success",
                    "result": {
                        "decision": result.decision,
                        "reasoning": result.reasoning,
                        "score": result.score,
                        "metadata": result.metadata
                    }
                })

            except Exception as e:
                await websocket.send_json({
                    "status": "error",
                    "error": str(e)
                })

    except WebSocketDisconnect:
        # Normal client disconnect.
        pass
    finally:
        # Runs for disconnects AND for any other error escaping the loop
        # (e.g. malformed JSON raised by receive_json), so the counter
        # reported by /health does not drift upward.
        active_connections -= 1
|
456
|
+
|
457
|
+
|
458
|
+
@app.post("/validate/template")
async def validate_template(request: Dict[str, Any]):
    """Check whether the supplied variables satisfy a template.

    Request body keys: ``template`` (default ``""``), ``template_vars``
    (default ``{}``; ``null`` is treated as empty) and ``template_engine``
    (default ``"format"``).

    Returns:
        A dict with ``valid``, the sorted ``required_vars``, ``provided_vars``
        and ``missing_vars`` lists, and either the rendered ``result`` (when
        strict rendering succeeds) or the rendering ``error``.

    Raises:
        HTTPException: 400 if the engine name is invalid or the template's
            required variables cannot be determined.
    """
    template = request.get("template", "")
    # `or {}` also covers an explicit JSON null.
    template_vars = request.get("template_vars") or {}
    engine_name = request.get("template_engine", "format")

    try:
        engine = TemplateEngine(engine_name)
        # Coerce to set so the difference below works for any iterable result.
        required_vars = set(TemplateProcessor.get_required_vars(template, engine))
        provided_vars = set(template_vars.keys())
        missing_vars = required_vars - provided_vars
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Validation error: {e}") from e

    # Sorted so the response is deterministic (set iteration order is not).
    summary = {
        "required_vars": sorted(required_vars),
        "provided_vars": sorted(provided_vars),
        "missing_vars": sorted(missing_vars),
    }

    try:
        result = TemplateProcessor.apply_template(
            template,
            template_vars,
            engine,
            strict=True
        )
    except Exception as e:
        return {"valid": False, **summary, "error": str(e)}

    return {"valid": True, **summary, "result": result}
|
503
|
+
|
504
|
+
|
505
|
+
@app.post("/metrics/register")
async def register_metric(metric_data: Dict[str, Any]):
    """Build a ``Metric`` from a JSON body and register it with the judge.

    Required keys: ``name`` and ``criteria``. Optional keys: ``rubric``,
    ``scale`` (converted to a tuple), ``examples``, ``system_prompt``,
    ``template_vars``, ``required_vars`` and ``template_engine`` (default
    ``"format"``).

    Raises:
        HTTPException: 503 if the judge is not initialized; 400 if a required
            key is missing, the engine name is invalid, or metric
            construction/registration fails.
    """
    if not judge:
        raise HTTPException(status_code=503, detail="Judge not initialized")

    # Report missing required keys explicitly instead of a bare KeyError text.
    missing = [key for key in ("name", "criteria") if key not in metric_data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=(
                "Failed to register metric: missing required field(s): "
                + ", ".join(missing)
            )
        )

    try:
        from vllm_judge.models import Metric

        metric = Metric(
            name=metric_data["name"],
            criteria=metric_data["criteria"],
            rubric=metric_data.get("rubric"),
            scale=tuple(metric_data["scale"]) if metric_data.get("scale") else None,
            examples=metric_data.get("examples", []),
            system_prompt=metric_data.get("system_prompt"),
            template_vars=metric_data.get("template_vars", {}),
            required_vars=metric_data.get("required_vars", []),
            # Convert to the enum here: this validates the engine name up
            # front, and GET /metrics reads `metric.template_engine.value`.
            template_engine=TemplateEngine(metric_data.get("template_engine", "format"))
        )

        judge.register_metric(metric)

        return {"message": f"Metric '{metric.name}' registered successfully"}

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to register metric: {e}") from e
|
534
|
+
|
535
|
+
|
536
|
+
def create_app(config: JudgeConfig) -> FastAPI:
    """Configure the module-level judge from ``config`` and return the app.

    The returned object is the module's single ``app``. Calling this again
    replaces the global judge without closing the previous one.
    """
    global judge
    configured = Judge(config)
    judge = configured
    return app
|
541
|
+
|
542
|
+
|
543
|
+
def start_server(
    base_url: str,
    model: Optional[str] = None,
    host: str = "0.0.0.0",
    port: int = 8080,
    reload: bool = False,
    **kwargs
):
    """Configure the global judge and run the API server with uvicorn.

    This call blocks until the server stops.

    Args:
        base_url: Passed to ``JudgeConfig.from_url``.
        model: Optional model name, passed to ``JudgeConfig.from_url``.
        host: Interface to bind.
        port: Port to bind.
        reload: Enable uvicorn auto-reload (see the note below).
        **kwargs: Extra options forwarded to ``JudgeConfig.from_url``.

    Note:
        The judge is stored in this process's module globals. With
        ``reload=True``, uvicorn must load the app from an import string in
        a separate worker process. That worker re-imports this module with
        ``judge`` unset, so endpoints that need the judge answer 503. A
        warning is emitted in that case.
    """
    global judge

    config = JudgeConfig.from_url(base_url, model=model, **kwargs)
    judge = Judge(config)

    if reload:
        import warnings

        warnings.warn(
            "reload=True: the uvicorn worker re-imports vllm_judge.api.server "
            "and will not see the judge configured here; endpoints will "
            "return 503 until a judge is initialized in that process.",
            RuntimeWarning,
            stacklevel=2,
        )
        # uvicorn only supports reload with an import string.
        target = "vllm_judge.api.server:app"
    else:
        # Serve this exact app object so it shares the judge configured above.
        target = app

    uvicorn.run(
        target,
        host=host,
        port=port,
        reload=reload
    )
|
vllm_judge/batch.py
ADDED
@@ -0,0 +1,147 @@
|
|
1
|
+
import asyncio
|
2
|
+
import time
|
3
|
+
from typing import List, Dict, Any, Callable, Optional, Union
|
4
|
+
from vllm_judge.models import EvaluationResult, BatchResult
|
5
|
+
from vllm_judge.exceptions import VLLMJudgeError
|
6
|
+
|
7
|
+
|
8
|
+
class BatchProcessor:
    """High-concurrency batch processing for evaluations."""

    def __init__(self, judge, max_concurrent: int = 50):
        """
        Initialize batch processor.

        Args:
            judge: Judge instance; its async ``evaluate`` is called per item
            max_concurrent: Maximum number of items evaluated at the same time
        """
        self.judge = judge
        # Caps how many _process_item calls are inside judge.evaluate at once.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        # Serializes updates to `completed` and calls to the progress callback.
        self.progress_lock = asyncio.Lock()
        # Items finished (success or failure) in the current `process` run.
        self.completed = 0

    async def process(
        self,
        data: List[Dict[str, Any]],
        progress_callback: Optional[Callable[[int, int], None]] = None,
        **default_kwargs
    ) -> BatchResult:
        """
        Process batch of evaluations.

        Args:
            data: List of evaluation inputs; each must contain a 'response'
            progress_callback: Optional callback, called as
                (completed_so_far, total) once per finished item
            **default_kwargs: Default parameters for all evaluations; keys
                present in an item override these

        Returns:
            BatchResult whose `results` are in input order. Each entry is
            either an EvaluationResult or the VLLMJudgeError describing
            that item's failure.
        """
        start_time = time.time()
        self.completed = 0
        total = len(data)

        tasks = []
        for i, item in enumerate(data):
            # Merge into a new dict so the caller's item is never mutated.
            eval_kwargs = {**default_kwargs, **item}

            task = self._process_item(
                eval_kwargs,
                i,
                total,
                progress_callback
            )
            tasks.append(task)

        # gather preserves input order; per-item failures are returned as
        # values by _process_item rather than raised.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        successful = sum(1 for r in results if isinstance(r, EvaluationResult))
        failed = total - successful
        duration = time.time() - start_time

        return BatchResult(
            results=results,
            total=total,
            successful=successful,
            failed=failed,
            duration_seconds=duration
        )

    async def _process_item(
        self,
        eval_kwargs: Dict[str, Any],
        index: int,
        total: int,
        progress_callback: Optional[Callable]
    ) -> Union[EvaluationResult, Exception]:
        """Evaluate one item under the concurrency limit.

        Failures (including a missing or empty 'response') are returned,
        not raised, as a VLLMJudgeError with `batch_index` and
        `original_error` attributes. Progress is reported exactly once per
        item, after its outcome is known.
        """
        async with self.semaphore:
            try:
                # Extract response from kwargs
                response = eval_kwargs.pop('response', None)
                if not response:
                    raise ValueError(f"Item {index} missing 'response' field")

                result = await self.judge.evaluate(response=response, **eval_kwargs)

                # Add index to metadata
                result.metadata['batch_index'] = index
                outcome = result

            except Exception as e:
                # Return exception with context
                error = VLLMJudgeError(f"Item {index} failed: {str(e)}")
                error.batch_index = index
                error.original_error = e
                outcome = error

            # Kept outside the try: a raising progress callback must not turn
            # a finished evaluation into a failure or be counted twice.
            async with self.progress_lock:
                self.completed += 1
                if progress_callback:
                    progress_callback(self.completed, total)

            return outcome

    async def process_streaming(
        self,
        data: List[Dict[str, Any]],
        callback: Callable[[int, Union[EvaluationResult, Exception]], None],
        **default_kwargs
    ):
        """
        Process batch with streaming results.

        Args:
            data: List of evaluation inputs
            callback: Called with (index, result) as each item finishes, in
                completion order rather than input order
            **default_kwargs: Default parameters for all evaluations; keys
                present in an item override these
        """
        async def process_and_callback(item, index):
            result = await self._process_item(
                {**default_kwargs, **item},
                index,
                len(data),
                None
            )
            callback(index, result)
            return result

        tasks = [
            process_and_callback(item, i)
            for i, item in enumerate(data)
        ]

        # Drive all items; the callback above delivers each result as it lands.
        for coro in asyncio.as_completed(tasks):
            await coro
|