llmasajudge 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmasajudge/__init__.py +1078 -312
- {llmasajudge-0.1.12.dist-info → llmasajudge-0.1.14.dist-info}/METADATA +1 -1
- llmasajudge-0.1.14.dist-info/RECORD +5 -0
- llmasajudge-0.1.12.dist-info/RECORD +0 -5
- {llmasajudge-0.1.12.dist-info → llmasajudge-0.1.14.dist-info}/WHEEL +0 -0
- {llmasajudge-0.1.12.dist-info → llmasajudge-0.1.14.dist-info}/top_level.txt +0 -0
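At a glance, 0.1.14 reworks llmasajudge's judging API around typed parser outputs (boolean, scalar, or a map of scores) with matching aggregation modes, and simplifies the LiteLLM disk-cache setup. As rough orientation only — a hypothetical usage sketch inferred from the new code in the diff below, not documented package API:

# Hypothetical sketch based on the 0.1.14 __init__.py shown below; not official docs.
from llmasajudge import LLMAsAJudge

judge = LLMAsAJudge(
    models=["openai/gpt-4o-mini"],   # any litellm model string
    output_parser="multi_score",     # new in 0.1.14: JSON map of dimension scores
    mode="average",                  # aggregation across judge models
)
result = judge.judge(
    input="What is 2 + 2?",
    model_output="4",
    ground_truth="4",
)
# result is expected to carry the aggregated value plus the per-model votes.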
llmasajudge/__init__.py
CHANGED
@@ -201,141 +201,726 @@
 
 
 
-import os
-import time
-import random
-import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
-import litellm
-from litellm import completion
-from litellm.caching.caching import Cache
[0.1.12 lines 212-333 — the previous UnlimitedDiskCache cache API plus surrounding blank lines — are removed:]
-    # Sync API used by LiteLLM
-    def get_cache(self, key, **kwargs):
-        """Get value from cache by key."""
-        return self._dc.get(key)
-    def set_cache(self, key, value, ttl=None, **kwargs):
-        """Set value in cache with optional TTL."""
-        expire = None if ttl is None else float(ttl)
-        self._dc.set(key, value, expire=expire)
-    # Async API used by LiteLLM
-    async def async_get_cache(self, key, **kwargs):
-        """Async get value from cache by key."""
-        return self.get_cache(key, **kwargs)
-    async def async_set_cache(self, key, value, ttl=None, **kwargs):
-        """Async set value in cache with optional TTL."""
-        return self.set_cache(key, value, ttl=ttl, **kwargs)
-    async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
-        """
-        Async batch set multiple cache entries.
-
-        Args:
-            cache_list: List of (key, value) tuples
-            ttl: Optional time-to-live in seconds
-        """
-        for k, v in cache_list:
-            self.set_cache(k, v, ttl=ttl)
-    async def batch_cache_write(self, key, value, ttl=None, **kwargs):
-        """Async batch write (single entry)."""
-        self.set_cache(key, value, ttl=ttl)
-    async def ping(self):
-        """Async ping check."""
-        return True
+# import os
+# import time
+# import random
+# import re
+# from typing import Any, Callable, Dict, List, Optional, Tuple
+# import litellm
+# from litellm import completion
+# from litellm.caching.caching import Cache
+
+# __all__ = ["LLMAsAJudge", "OutputParsers"]
[new lines 215-857: the entire 0.1.12 implementation is carried along as commented-out code — UnlimitedDiskCache (constructor, sync/async cache API, get_stats, print_stats), the OutputParsers methods (right_wrong, pass_fail, yes_no, numeric_score, json_extract), and LLMAsAJudge (BASE_TEMPLATE, PARSER_INSTRUCTIONS, two commented copies of __init__ including a variant with use_cache/litellm_cache_dir/cache_size_gb, the cache setup, _build_prompt, _last6_right_wrong, _resolve_per_model, _attempt_completion, _ask_model, judge) — before the active module code resumes:]
+import os
+import time
+import random
+import re
+import json
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from enum import Enum
+import litellm
+from litellm import completion
+from litellm.caching.caching import Cache
+
+__all__ = ["LLMAsAJudge", "OutputParsers"]
+
+
+class ReturnType(Enum):
+    """Enum for categorizing parser return types."""
+    BOOLEAN = "boolean"
+    SCALAR = "scalar"
+    MAP = "map"
+
+
+class AggregationMode(Enum):
+    """Enum for aggregation modes across multiple voters."""
+    # Boolean modes
+    MAJORITY = "majority"
+    SINGLE = "single"
+    ALL = "all"
+    # Scalar modes
+    AVERAGE = "average"
+    MIN = "min"
+    MAX = "max"
+    MEDIAN = "median"
+
+
+# Valid aggregation modes per return type
+VALID_MODES = {
+    ReturnType.BOOLEAN: {AggregationMode.MAJORITY, AggregationMode.SINGLE, AggregationMode.ALL},
+    ReturnType.SCALAR: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
+    ReturnType.MAP: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
+}
+
+# Default aggregation modes per return type
+DEFAULT_MODES = {
+    ReturnType.BOOLEAN: AggregationMode.MAJORITY,
+    ReturnType.SCALAR: AggregationMode.AVERAGE,
+    ReturnType.MAP: AggregationMode.AVERAGE,
+}
+
+# String to enum mapping (for backward compat)
+MODE_STR_MAP = {
+    'majority': AggregationMode.MAJORITY,
+    'single': AggregationMode.SINGLE,
+    'all': AggregationMode.ALL,
+    'average': AggregationMode.AVERAGE,
+    'min': AggregationMode.MIN,
+    'max': AggregationMode.MAX,
+    'median': AggregationMode.MEDIAN,
+}
 
 
 class OutputParsers:
     """Stock output parsers for common judge output formats."""
 
+    # ========== BOOLEAN PARSERS ==========
+
     @staticmethod
     def right_wrong(s: str) -> Optional[bool]:
         """Parse 'right' or 'wrong' from the last 6 characters."""
@@ -372,6 +957,8 @@ class OutputParsers:
             return False
         return None
 
+    # ========== SCALAR PARSERS ==========
+
     @staticmethod
     def numeric_score(s: str) -> Optional[float]:
         """Extract first numeric value from the response."""
@@ -382,20 +969,188 @@ class OutputParsers:
             return float(match.group())
         return None
 
+    @staticmethod
+    def score_0_to_10(s: str) -> Optional[float]:
+        """Extract a score from 0-10, clamping to valid range."""
+        if not s:
+            return None
+        match = re.search(r'[-+]?\d*\.?\d+', s.strip())
+        if match:
+            return max(0.0, min(10.0, float(match.group())))
+        return None
+
+    @staticmethod
+    def score_0_to_100(s: str) -> Optional[float]:
+        """Extract a score from 0-100, clamping to valid range."""
+        if not s:
+            return None
+        match = re.search(r'[-+]?\d*\.?\d+', s.strip())
+        if match:
+            return max(0.0, min(100.0, float(match.group())))
+        return None
+
+    @staticmethod
+    def score_1_to_5(s: str) -> Optional[float]:
+        """Extract a score from 1-5, clamping to valid range."""
+        if not s:
+            return None
+        match = re.search(r'[-+]?\d*\.?\d+', s.strip())
+        if match:
+            return max(1.0, min(5.0, float(match.group())))
+        return None
+
+    # ========== MAP PARSERS ==========
+
     @staticmethod
     def json_extract(key: str) -> Callable[[str], Any]:
         """Create a parser that extracts a specific key from JSON output."""
-        import json
         def parser(s: str) -> Any:
             if not s:
                 return None
             try:
-                data = json.loads(s.strip())
+                s = s.strip()
+                if "```json" in s:
+                    start = s.find("```json") + 7
+                    end = s.find("```", start)
+                    s = s[start:end].strip()
+                elif "```" in s:
+                    start = s.find("```") + 3
+                    end = s.find("```", start)
+                    s = s[start:end].strip()
+                data = json.loads(s)
                 return data.get(key)
             except (json.JSONDecodeError, AttributeError):
                 return None
         return parser
 
+    @staticmethod
+    def json_map(keys: Optional[List[str]] = None) -> Callable[[str], Optional[Dict[str, float]]]:
+        """
+        Create a parser that extracts multiple keys from JSON output as a map.
+        Returns Dict[str, float] or None.
+
+        More robust version that handles:
+        - Code blocks with ```json or ```
+        - Extra text before/after JSON
+        - Various JSON formatting issues
+        """
+        def parser(s: str) -> Optional[Dict[str, float]]:
+            if not s:
+                return None
+            try:
+                s = s.strip()
+
+                # Handle markdown code blocks
+                if "```json" in s.lower():
+                    start = s.lower().find("```json") + 7
+                    end = s.find("```", start)
+                    if end > start:
+                        s = s[start:end].strip()
+                elif "```" in s:
+                    start = s.find("```") + 3
+                    end = s.find("```", start)
+                    if end > start:
+                        s = s[start:end].strip()
+
+                # Try to find JSON object if there's extra text
+                # Look for first { and last }
+                if '{' in s and '}' in s:
+                    start_brace = s.find('{')
+                    end_brace = s.rfind('}')
+                    if start_brace < end_brace:
+                        s = s[start_brace:end_brace + 1]
+
+                data = json.loads(s)
+                if not isinstance(data, dict):
+                    return None
+
+                result = {}
+                target_keys = keys if keys else list(data.keys())
+
+                for key in target_keys:
+                    if key in data:
+                        val = data[key]
+                        if isinstance(val, (int, float)):
+                            result[key] = float(val)
+                        elif isinstance(val, str):
+                            try:
+                                result[key] = float(val)
+                            except ValueError:
+                                pass
+
+                return result if result else None
+            except (json.JSONDecodeError, AttributeError, ValueError):
+                return None
+        return parser
+
+    @staticmethod
+    def multi_score_pattern(pattern: str = r'(\w+):\s*([\d.]+)') -> Callable[[str], Optional[Dict[str, float]]]:
+        """
+        Create a parser that extracts key-value pairs using regex pattern.
+        Default pattern matches "key: value" format.
+        """
+        def parser(s: str) -> Optional[Dict[str, float]]:
+            if not s:
+                return None
+            matches = re.findall(pattern, s)
+            if not matches:
+                return None
+            result = {}
+            for key, val in matches:
+                try:
+                    result[key.lower()] = float(val)
+                except ValueError:
+                    pass
+            return result if result else None
+        return parser
+
+
+def _infer_return_type(value: Any) -> Optional[ReturnType]:
+    """Infer the return type from a parsed value."""
+    if value is None:
+        return None
+    if isinstance(value, bool):
+        return ReturnType.BOOLEAN
+    if isinstance(value, (int, float)):
+        return ReturnType.SCALAR
+    if isinstance(value, dict) and all(isinstance(v, (int, float)) for v in value.values()):
+        return ReturnType.MAP
+    return None
+
+
+def _validate_consistent_types(votes: List[Dict[str, Any]]) -> ReturnType:
+    """Validate that all votes have consistent return types."""
+    return_types = set()
+    for vote in votes:
+        result = vote.get("result")
+        if result is not None:
+            rt = _infer_return_type(result)
+            if rt is not None:
+                return_types.add(rt)
+
+    if len(return_types) == 0:
+        raise ValueError("All parsers returned None - cannot determine return type")
+    if len(return_types) > 1:
+        raise ValueError(
+            f"Mixed return types detected: {[rt.value for rt in return_types]}. "
+            "All judges must return the same type (boolean, scalar, or map)."
+        )
+    return return_types.pop()
+
+
+def _validate_consistent_map_keys(votes: List[Dict[str, Any]]) -> set:
+    """Validate that all map votes have the same keys."""
+    all_keys = None
+    for vote in votes:
+        result = vote.get("result")
+        if isinstance(result, dict):
+            keys = set(result.keys())
+            if all_keys is None:
+                all_keys = keys
+            elif keys != all_keys:
+                raise ValueError(f"Inconsistent map keys across voters. Expected {all_keys}, got {keys}")
+    return all_keys or set()
+
 
 class LLMAsAJudge:
     BASE_TEMPLATE = """\
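The map parsers above are the basis of the new 'multi_score' judging mode. A minimal sketch of how they behave, assuming the 0.1.14 OutputParsers shown above is importable as is (the raw strings are made-up examples):

from llmasajudge import OutputParsers

parse = OutputParsers.json_map(keys=["accuracy", "helpfulness"])
raw = 'Scores:\n```json\n{"accuracy": 8, "helpfulness": "7", "notes": "ok"}\n```'
print(parse(raw))                                    # {'accuracy': 8.0, 'helpfulness': 7.0}

pattern_parse = OutputParsers.multi_score_pattern()  # default pattern r'(\w+):\s*([\d.]+)'
print(pattern_parse("Accuracy: 8.5\nRelevance: 9"))  # {'accuracy': 8.5, 'relevance': 9.0}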
@@ -439,123 +1194,13 @@ Return a single numeric score from 0-10 indicating how well the model output matches the ground truth.
 - 1-3 = poor match
 - 0 = completely wrong
 Output only the number. No explanation. No extra text.""",
+        'multi_score': """\
+Evaluate the model output against the ground truth on multiple dimensions.
+Return your scores in JSON format with numeric values (0-10 scale).
+Example: {"accuracy": 8, "helpfulness": 7, "relevance": 9}
+Output only valid JSON. No explanation. No extra text.""",
     }
 
[0.1.12 lines 444-558 are removed: blank lines plus a fully commented-out copy of the old __init__ signature and its docstring (models, config, base_headers, wandb_project, custom_template, use_fully_custom_prompt, notes, output_parser, fallback_comparison, default_temperature, verbose, num_retries, backoff_base, backoff_max, custom_generation_fns, mode).]
     def __init__(
         self,
         models: Optional[List[str]] = None,
@@ -565,7 +1210,7 @@ Output only the number. No explanation. No extra text.""",
         custom_template: Optional[str] = None,
         use_fully_custom_prompt: bool = False,
         notes: Optional[str] = None,
-        output_parser: Optional[str] = 'right/wrong',
+        output_parser: Optional[Union[str, Callable]] = 'right/wrong',
         fallback_comparison: bool = True,
         default_temperature: float = 0.0,
         verbose: bool = False,
@@ -573,9 +1218,9 @@ Output only the number. No explanation. No extra text.""",
         backoff_base: float = 0.5,
         backoff_max: float = 4.0,
         custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
-        mode: str = "majority",
+        mode: Optional[str] = None,
         litellm_cache_dir: Optional[str] = None,
-        cache_size_gb: Optional[float] = None,
+        return_type: Optional[str] = None,
     ):
         self.models = models or []
         self.custom_generation_fns = custom_generation_fns or []
@@ -583,15 +1228,11 @@ Output only the number. No explanation. No extra text.""",
         if not self.models and not self.custom_generation_fns:
             raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
 
-        if mode not in ("majority", "single", "all"):
-            raise ValueError("mode must be 'majority', 'single', or 'all'")
-
         self.config = config or {}
         self.base_headers = dict(base_headers or {})
         self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
         self.notes = notes or ""
         self.use_fully_custom_prompt = use_fully_custom_prompt
-        self.mode = mode
         self.fallback_comparison = fallback_comparison
         self.default_temperature = float(default_temperature)
         self.verbose = verbose
@@ -599,6 +1240,7 @@ Output only the number. No explanation. No extra text.""",
         self.backoff_base = float(backoff_base)
         self.backoff_max = float(backoff_max)
 
+        # Resolve output parser
         parser_name = None
         if isinstance(output_parser, str):
             parser_map = {
@@ -606,6 +1248,7 @@ Output only the number. No explanation. No extra text.""",
                 'pass/fail': OutputParsers.pass_fail,
                 'yes/no': OutputParsers.yes_no,
                 'numeric': OutputParsers.numeric_score,
+                'multi_score': OutputParsers.json_map(),
             }
             if output_parser not in parser_map:
                 raise ValueError(f"Unknown parser '{output_parser}'")
@@ -614,6 +1257,42 @@ Output only the number. No explanation. No extra text.""",
         else:
             self.output_parser = output_parser
 
+        # Determine expected return type
+        self._explicit_return_type: Optional[ReturnType] = None
+        if return_type is not None:
+            self._explicit_return_type = ReturnType(return_type)
+        elif parser_name:
+            if parser_name in ('right/wrong', 'pass/fail', 'yes/no'):
+                self._explicit_return_type = ReturnType.BOOLEAN
+            elif parser_name == 'numeric':
+                self._explicit_return_type = ReturnType.SCALAR
+            elif parser_name == 'multi_score':
+                self._explicit_return_type = ReturnType.MAP
+
+        # Resolve aggregation mode
+        if mode is not None:
+            if mode not in MODE_STR_MAP:
+                raise ValueError(f"Unknown mode '{mode}'. Available: {list(MODE_STR_MAP.keys())}")
+            self._mode = MODE_STR_MAP[mode]
+        else:
+            if self._explicit_return_type:
+                self._mode = DEFAULT_MODES[self._explicit_return_type]
+            else:
+                self._mode = AggregationMode.MAJORITY
+
+        # Validate mode against return type if known
+        if self._explicit_return_type:
+            valid_modes = VALID_MODES[self._explicit_return_type]
+            if self._mode not in valid_modes:
+                raise ValueError(
+                    f"Mode '{self._mode.value}' not valid for return type '{self._explicit_return_type.value}'. "
+                    f"Valid: {[m.value for m in valid_modes]}"
+                )
+
+        # For backward compat, expose mode as string
+        self.mode = self._mode.value
+
+        # Set template
         if use_fully_custom_prompt:
             self.template = None
         elif custom_template:
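Taken together, the constructor logic above means the aggregation mode can be omitted and is then derived from the parser's return type, while an explicit mode is validated against that type. A small sketch under that assumption (model names are placeholders):

j1 = LLMAsAJudge(models=["openai/gpt-4o-mini"], output_parser="numeric")
print(j1.mode)   # "average" - scalar parsers default to AggregationMode.AVERAGE

j2 = LLMAsAJudge(models=["openai/gpt-4o-mini"], output_parser="yes/no", mode="all")
print(j2.mode)   # "all" - explicit mode, valid for boolean parsers

# LLMAsAJudge(models=["openai/gpt-4o-mini"], output_parser="numeric", mode="all")
# raises ValueError: mode "all" is not valid for return type "scalar".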
@@ -635,43 +1314,23 @@ Output only the number. No explanation. No extra text.""",
                 ground_truth="{ground_truth}",
             )
 
-        # optional local cache setup
+        # Optional local cache setup
         self.cache_enabled = litellm_cache_dir is not None
         if self.cache_enabled:
-            size_limit_bytes = None if cache_size_gb is None else int(cache_size_gb * 1024 * 1024 * 1024)
-            cache_backend = UnlimitedDiskCache(litellm_cache_dir, size_limit=size_limit_bytes)
-            litellm.cache = Cache(disk_cache_dir=litellm_cache_dir)
-            litellm.cache.cache = cache_backend
+            litellm.cache = Cache(type="disk", disk_cache_dir=litellm_cache_dir)
 
     def _build_prompt(self, input: Any, model_output: Any, ground_truth: Any) -> str:
         notes_section = f"notes:\n{self.notes}\n" if self.notes else ""
         input_text = str(input) if input not in (None, "") else "[omitted input for brevity]"
-        return self.template.format(
-            notes_section=notes_section,
-            input_block=input_text,
-            model_output=str(model_output),
-            ground_truth=str(ground_truth),
-        )
-            return False
-        return None
+        # Use string replacement instead of .format() to avoid issues with { } in template
+        # (e.g., JSON examples in multi_score instructions)
+        prompt = self.template
+        prompt = prompt.replace("{notes_section}", notes_section)
+        prompt = prompt.replace("{input_block}", input_text)
+        prompt = prompt.replace("{model_output}", str(model_output))
+        prompt = prompt.replace("{ground_truth}", str(ground_truth))
+        return prompt
 
     def _resolve_per_model(self, model: str) -> Tuple[Optional[str], Dict[str, str], float]:
         provider = model.split("/", 1)[0] if "/" in model else model
@@ -680,7 +1339,6 @@ Output only the number. No explanation. No extra text.""",
         headers: Dict[str, str] = dict(self.base_headers)
         temperature: float = self.default_temperature
 
-        # provider-level defaults
         pc = self.config.get(provider, {})
         if pc.get("api_base") is not None:
             api_base = pc["api_base"]
@@ -688,7 +1346,6 @@ Output only the number. No explanation. No extra text.""",
         if "temperature" in pc:
             temperature = float(pc["temperature"])
 
-        # model-level overrides
         mc = self.config.get(model, {})
         if mc.get("api_base") is not None:
             api_base = mc["api_base"]
@@ -696,7 +1353,6 @@ Output only the number. No explanation. No extra text.""",
         if "temperature" in mc:
             temperature = float(mc["temperature"])
 
-        # wandb defaults
         if provider == "wandb":
             if api_base is None:
                 api_base = "https://api.inference.wandb.ai/v1"
@@ -718,15 +1374,6 @@ Output only the number. No explanation. No extra text.""",
         last_err = None
         for i in range(attempts):
             try:
-                # resp = completion(
-                #     model=model,
-                #     api_base=api_base,  # None uses provider default
-                #     messages=[{"role": "user", "content": prompt}],
-                #     temperature=temperature,
-                #     max_tokens=max_tokens,
-                #     extra_headers=headers,
-                # )
-
                 resp = completion(
                     model=model,
                     api_base=api_base,
@@ -735,20 +1382,20 @@ Output only the number. No explanation. No extra text.""",
|
|
|
735
1382
|
max_tokens=max_tokens,
|
|
736
1383
|
extra_headers=headers,
|
|
737
1384
|
caching=self.cache_enabled
|
|
738
|
-
)
|
|
1385
|
+
)
|
|
739
1386
|
return (resp.choices[0].message.content or "").strip()
|
|
740
1387
|
except Exception as e:
|
|
741
1388
|
last_err = e
|
|
742
1389
|
if i == attempts - 1:
|
|
743
1390
|
break
|
|
744
1391
|
sleep_s = min(self.backoff_max, self.backoff_base * (2 ** i))
|
|
745
|
-
jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))
|
|
1392
|
+
jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))
|
|
746
1393
|
if self.verbose:
|
|
747
1394
|
print(f"[retry {i+1}/{attempts-1}] {model} error: {e} — sleeping {max(0.0, sleep_s + jitter):.2f}s", flush=True)
|
|
748
1395
|
time.sleep(max(0.0, sleep_s + jitter))
|
|
749
|
-
raise last_err
|
|
1396
|
+
raise last_err
|
|
750
1397
|
|
|
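
The retry loop above sleeps between attempts with exponential backoff, capped at `backoff_max` and perturbed by roughly ±10% jitter. A standalone sketch of the schedule (the `backoff_base`, `backoff_max`, and `attempts` values are assumed for the example, not the package defaults):

```python
import random

backoff_base, backoff_max, attempts = 1.0, 30.0, 5

# The loop only sleeps after failed attempts; the final attempt re-raises instead.
for i in range(attempts - 1):
    sleep_s = min(backoff_max, backoff_base * (2 ** i))      # 1, 2, 4, 8, ... capped
    jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))   # uniform in ±10% of sleep_s
    print(f"attempt {i + 1}: sleep ~{max(0.0, sleep_s + jitter):.2f}s")
```
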
-    def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any):
+    def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any) -> Any:
         api_base, headers, temperature = self._resolve_per_model(model)
 
         content = self._attempt_completion(
@@ -760,15 +1407,80 @@ Output only the number. No explanation. No extra text.""",
             max_tokens=max_tokens,
         )
 
-        # Use the instance parser
         parsed = self.output_parser(content)
 
-        # If parser returns None and fallback is enabled, do string comparison
         if parsed is None and self.fallback_comparison:
             return str(model_output).strip() == str(ground_truth).strip()
 
         return parsed
 
+    def _aggregate_boolean(self, votes: List[Dict[str, Any]]) -> bool:
+        results = [v["result"] for v in votes if v["result"] is not None]
+        if not results:
+            raise ValueError("No valid votes to aggregate")
+
+        if self._mode == AggregationMode.SINGLE:
+            return bool(results[0])
+        elif self._mode == AggregationMode.MAJORITY:
+            return sum(1 for r in results if r) >= len(results) / 2
+        elif self._mode == AggregationMode.ALL:
+            return all(results)
+        else:
+            raise ValueError(f"Invalid mode for boolean: {self._mode}")
+
+    def _aggregate_scalar(self, votes: List[Dict[str, Any]]) -> float:
+        results = [float(v["result"]) for v in votes if v["result"] is not None]
+        if not results:
+            raise ValueError("No valid votes to aggregate")
+
+        if self._mode == AggregationMode.SINGLE:
+            return results[0]
+        elif self._mode == AggregationMode.AVERAGE:
+            return sum(results) / len(results)
+        elif self._mode == AggregationMode.MIN:
+            return min(results)
+        elif self._mode == AggregationMode.MAX:
+            return max(results)
+        elif self._mode == AggregationMode.MEDIAN:
+            s = sorted(results)
+            n = len(s)
+            mid = n // 2
+            return (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
+        else:
+            raise ValueError(f"Invalid mode for scalar: {self._mode}")
+
+    def _aggregate_map(self, votes: List[Dict[str, Any]]) -> Dict[str, float]:
+        valid = [v["result"] for v in votes if v["result"] is not None and isinstance(v["result"], dict)]
+        if not valid:
+            raise ValueError("No valid map votes to aggregate")
+
+        keys = set()
+        for v in valid:
+            keys.update(v.keys())
+
+        if self._mode == AggregationMode.SINGLE:
+            return valid[0]
+
+        result = {}
+        for key in keys:
+            values = [v[key] for v in valid if key in v]
+            if not values:
+                continue
+
+            if self._mode == AggregationMode.AVERAGE:
+                result[key] = sum(values) / len(values)
+            elif self._mode == AggregationMode.MIN:
+                result[key] = min(values)
+            elif self._mode == AggregationMode.MAX:
+                result[key] = max(values)
+            elif self._mode == AggregationMode.MEDIAN:
+                s = sorted(values)
+                n = len(s)
+                mid = n // 2
+                result[key] = (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
+
+        return result
+
     def judge(
         self,
         input: Any = None,
@@ -776,23 +1488,25 @@ Output only the number. No explanation. No extra text.""",
         ground_truth: Any = None,
         prompt: Optional[str] = None,
         max_tokens: int = 10000,
-    ):
-
+    ) -> Dict[str, Any]:
+        """
+        Run the judge evaluation.
+
+        Returns dict with:
+        - 'correct': bool or None (bool for boolean type, None for scalar/map)
+        - 'scores': float or Dict[str, float] or None (for scalar/map, None for boolean)
+        - 'mode': aggregation mode string
+        - 'votes': list of individual judge votes
+        """
         if self.use_fully_custom_prompt:
             if prompt is None:
-                raise ValueError(
-                    "When use_fully_custom_prompt=True, you must pass prompt to judge()."
-                )
+                raise ValueError("When use_fully_custom_prompt=True, you must pass prompt to judge().")
             if input is not None or model_output is not None or ground_truth is not None:
                 raise ValueError(
-                    "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth
-                    "Only pass the complete prompt."
+                    "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth."
                 )
         elif prompt is not None:
-            raise ValueError(
-                "prompt parameter can only be used when use_fully_custom_prompt=True. "
-                "Otherwise, use input/model_output/ground_truth."
-            )
+            raise ValueError("prompt parameter can only be used when use_fully_custom_prompt=True.")
         else:
             prompt = self._build_prompt(input, model_output, ground_truth)
 
@@ -800,18 +1514,27 @@ Output only the number. No explanation. No extra text.""",
 
         # Vote with litellm models
        for m in self.models:
-
-
-
-
+            try:
+                res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
+                if self.verbose:
+                    print(f"Model {m} voted: {res}", flush=True)
+                votes.append({"model": m, "result": res})
+            except Exception as e:
+                if self.verbose:
+                    print(f"Model {m} failed: {e}", flush=True)
+                if self.fallback_comparison:
+                    res = str(model_output).strip() == str(ground_truth).strip()
+                    votes.append({"model": m, "result": res, "error": str(e)})
+                else:
+                    raise
 
         # Vote with custom generation functions
         for idx, custom_fn in enumerate(self.custom_generation_fns):
+            fn_name = f"custom_fn_{idx}"
             try:
                 content = custom_fn(prompt)
                 parsed = self.output_parser(content)
 
-                # If parser returns None and fallback is enabled, do string comparison
                 if parsed is None and self.fallback_comparison:
                     res = str(model_output).strip() == str(ground_truth).strip()
                 else:
@@ -819,27 +1542,70 @@ Output only the number. No explanation. No extra text.""",
 
                 if self.verbose:
                     print(f"Custom function {idx} voted: {res}", flush=True)
-                votes.append({"model":
+                votes.append({"model": fn_name, "result": res})
             except Exception as e:
                 if self.verbose:
                     print(f"Custom function {idx} failed: {e}", flush=True)
-                # If custom function fails and fallback is enabled, do string comparison
                 if self.fallback_comparison:
                     res = str(model_output).strip() == str(ground_truth).strip()
-                    votes.append({"model":
+                    votes.append({"model": fn_name, "result": res, "error": str(e)})
                 else:
                     raise
 
-
-
-
-
-            false_votes = len(votes) - true_votes
-            final = True if true_votes >= false_votes else False
-        elif self.mode == "all":
-            final = all(v["correct"] for v in votes)
+        # Determine return type
+        return_type = self._explicit_return_type
+        if return_type is None:
+            return_type = _validate_consistent_types(votes)
         else:
-
-
-
-
+            actual = _validate_consistent_types(votes)
+            if actual != return_type:
+                raise ValueError(f"Expected return type '{return_type.value}' but got '{actual.value}'")
+
+        # Auto-correct mode if needed
+        if self._mode not in VALID_MODES[return_type]:
+            self._mode = DEFAULT_MODES[return_type]
+            self.mode = self._mode.value
+
+        # Validate map keys
+        if return_type == ReturnType.MAP:
+            _validate_consistent_map_keys(votes)
+
+        # Aggregate
+        if return_type == ReturnType.BOOLEAN:
+            final = self._aggregate_boolean(votes)
+        elif return_type == ReturnType.SCALAR:
+            final = self._aggregate_scalar(votes)
+        elif return_type == ReturnType.MAP:
+            final = self._aggregate_map(votes)
+        else:
+            raise ValueError(f"Unknown return type: {return_type}")
+
+        # Build backward-compatible response
+        # Boolean: correct=bool, scores=None
+        # Scalar: correct=score, scores=score (both fields for convenience)
+        # Map: correct=None, scores=map
+        if return_type == ReturnType.BOOLEAN:
+            # Also put "correct" in each vote for backward compat
+            for v in votes:
+                v["correct"] = v["result"]
+            return {
+                "correct": final,
+                "scores": None,
+                "mode": self.mode,
+                "votes": votes,
+            }
+        elif return_type == ReturnType.SCALAR:
+            # For scalar, put score in both correct and scores for convenience
+            return {
+                "correct": final,
+                "scores": final,
+                "mode": self.mode,
+                "votes": votes,
+            }
+        else: # MAP
+            return {
+                "correct": None,
+                "scores": final,
+                "mode": self.mode,
+                "votes": votes,
+            }
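
Per the docstring and return blocks added above, `judge()` now comes back as a dict with `correct`, `scores`, `mode`, and `votes`; boolean results aggregate by a simple at-least-half majority, scalars by average/min/max/median. A standalone mirror of the scalar median path and the resulting shape (the vote values and model names are invented for illustration; no package import is used):

```python
# Three hypothetical scalar votes, in the same {"model", "result"} form the
# voting loops build.
votes = [
    {"model": "judge-a", "result": 0.2},
    {"model": "judge-b", "result": 0.9},
    {"model": "judge-c", "result": 0.6},
]

# MEDIAN branch of _aggregate_scalar, reproduced verbatim.
s = sorted(float(v["result"]) for v in votes)
n, mid = len(s), len(s) // 2
median = (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]

# For scalar results the aggregate lands in both 'correct' and 'scores'.
print({"correct": median, "scores": median, "mode": "median", "votes": votes})
```
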