llmasajudge 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmasajudge/__init__.py +1079 -321
- {llmasajudge-0.1.13.dist-info → llmasajudge-0.1.14.dist-info}/METADATA +1 -1
- llmasajudge-0.1.14.dist-info/RECORD +5 -0
- llmasajudge-0.1.13.dist-info/RECORD +0 -5
- {llmasajudge-0.1.13.dist-info → llmasajudge-0.1.14.dist-info}/WHEEL +0 -0
- {llmasajudge-0.1.13.dist-info → llmasajudge-0.1.14.dist-info}/top_level.txt +0 -0
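Judging from the diff below, the 0.1.14 changes to llmasajudge/__init__.py add ReturnType and AggregationMode enums, scalar and map output parsers (score_0_to_10, score_0_to_100, score_1_to_5, json_map, multi_score_pattern), per-mode aggregation helpers, and a reworked LLMAsAJudge.__init__ (output_parser widens to Union[str, Callable], mode becomes optional, return_type is added, and use_cache/cache_size_gb give way to litellm_cache_dir). A minimal usage sketch based only on the new constructor signature shown in this diff; the model string and inputs are illustrative, not taken from the package documentation:

from llmasajudge import LLMAsAJudge

# Hypothetical example: parameter names come from the 0.1.14 __init__ shown below.
judge = LLMAsAJudge(
    models=["openai/gpt-4o-mini"],        # any litellm model string (illustrative)
    output_parser="multi_score",          # new built-in map parser (OutputParsers.json_map())
    mode="average",                       # new scalar/map aggregation mode
    litellm_cache_dir=".litellm_cache",   # presence of a cache dir now enables disk caching
)

result = judge.judge(
    input="Summarize the ticket",
    model_output="The user cannot log in after the 2.3 update.",
    ground_truth="Login fails after upgrading to 2.3.",
)
print(result)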
llmasajudge/__init__.py
CHANGED
@@ -201,141 +201,726 @@
 
 
 
-import os
-import time
-import random
-import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
-import litellm
-from litellm import completion
-from litellm.caching.caching import Cache
[... new-file lines 204-853: the entire 0.1.13 module body (the old imports, __all__, the UnlimitedDiskCache disk-cache backend, OutputParsers, and LLMAsAJudge including its judge() voting loop) is retained only as a '#'-commented block, while the active 0.1.13 UnlimitedDiskCache definition and the surrounding blank lines in this range are deleted ...]
+import os
+import time
+import random
+import re
+import json
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from enum import Enum
+import litellm
+from litellm import completion
+from litellm.caching.caching import Cache
+
+
+__all__ = ["LLMAsAJudge", "OutputParsers"]
+
+
+class ReturnType(Enum):
+    """Enum for categorizing parser return types."""
+    BOOLEAN = "boolean"
+    SCALAR = "scalar"
+    MAP = "map"
+
+
+class AggregationMode(Enum):
+    """Enum for aggregation modes across multiple voters."""
+    # Boolean modes
+    MAJORITY = "majority"
+    SINGLE = "single"
+    ALL = "all"
+    # Scalar modes
+    AVERAGE = "average"
+    MIN = "min"
+    MAX = "max"
+    MEDIAN = "median"
+
+
+# Valid aggregation modes per return type
+VALID_MODES = {
+    ReturnType.BOOLEAN: {AggregationMode.MAJORITY, AggregationMode.SINGLE, AggregationMode.ALL},
+    ReturnType.SCALAR: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
+    ReturnType.MAP: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
+}
+
+# Default aggregation modes per return type
+DEFAULT_MODES = {
+    ReturnType.BOOLEAN: AggregationMode.MAJORITY,
+    ReturnType.SCALAR: AggregationMode.AVERAGE,
+    ReturnType.MAP: AggregationMode.AVERAGE,
+}
+
+# String to enum mapping (for backward compat)
+MODE_STR_MAP = {
+    'majority': AggregationMode.MAJORITY,
+    'single': AggregationMode.SINGLE,
+    'all': AggregationMode.ALL,
+    'average': AggregationMode.AVERAGE,
+    'min': AggregationMode.MIN,
+    'max': AggregationMode.MAX,
+    'median': AggregationMode.MEDIAN,
+}
 
 
 class OutputParsers:
     """Stock output parsers for common judge output formats."""
 
+    # ========== BOOLEAN PARSERS ==========
+
     @staticmethod
     def right_wrong(s: str) -> Optional[bool]:
         """Parse 'right' or 'wrong' from the last 6 characters."""
@@ -372,6 +957,8 @@ class OutputParsers:
             return False
         return None
 
+    # ========== SCALAR PARSERS ==========
+
     @staticmethod
     def numeric_score(s: str) -> Optional[float]:
         """Extract first numeric value from the response."""
@@ -382,20 +969,188 @@ class OutputParsers:
             return float(match.group())
         return None
 
+    @staticmethod
+    def score_0_to_10(s: str) -> Optional[float]:
+        """Extract a score from 0-10, clamping to valid range."""
+        if not s:
+            return None
+        match = re.search(r'[-+]?\d*\.?\d+', s.strip())
+        if match:
+            return max(0.0, min(10.0, float(match.group())))
+        return None
+
+    @staticmethod
+    def score_0_to_100(s: str) -> Optional[float]:
+        """Extract a score from 0-100, clamping to valid range."""
+        if not s:
+            return None
+        match = re.search(r'[-+]?\d*\.?\d+', s.strip())
+        if match:
+            return max(0.0, min(100.0, float(match.group())))
+        return None
+
+    @staticmethod
+    def score_1_to_5(s: str) -> Optional[float]:
+        """Extract a score from 1-5, clamping to valid range."""
+        if not s:
+            return None
+        match = re.search(r'[-+]?\d*\.?\d+', s.strip())
+        if match:
+            return max(1.0, min(5.0, float(match.group())))
+        return None
+
+    # ========== MAP PARSERS ==========
+
     @staticmethod
     def json_extract(key: str) -> Callable[[str], Any]:
         """Create a parser that extracts a specific key from JSON output."""
-        import json
         def parser(s: str) -> Any:
             if not s:
                 return None
             try:
-
+                s = s.strip()
+                if "```json" in s:
+                    start = s.find("```json") + 7
+                    end = s.find("```", start)
+                    s = s[start:end].strip()
+                elif "```" in s:
+                    start = s.find("```") + 3
+                    end = s.find("```", start)
+                    s = s[start:end].strip()
+                data = json.loads(s)
                 return data.get(key)
             except (json.JSONDecodeError, AttributeError):
                 return None
         return parser
 
+    @staticmethod
+    def json_map(keys: Optional[List[str]] = None) -> Callable[[str], Optional[Dict[str, float]]]:
+        """
+        Create a parser that extracts multiple keys from JSON output as a map.
+        Returns Dict[str, float] or None.
+
+        More robust version that handles:
+        - Code blocks with ```json or ```
+        - Extra text before/after JSON
+        - Various JSON formatting issues
+        """
+        def parser(s: str) -> Optional[Dict[str, float]]:
+            if not s:
+                return None
+            try:
+                s = s.strip()
+
+                # Handle markdown code blocks
+                if "```json" in s.lower():
+                    start = s.lower().find("```json") + 7
+                    end = s.find("```", start)
+                    if end > start:
+                        s = s[start:end].strip()
+                elif "```" in s:
+                    start = s.find("```") + 3
+                    end = s.find("```", start)
+                    if end > start:
+                        s = s[start:end].strip()
+
+                # Try to find JSON object if there's extra text
+                # Look for first { and last }
+                if '{' in s and '}' in s:
+                    start_brace = s.find('{')
+                    end_brace = s.rfind('}')
+                    if start_brace < end_brace:
+                        s = s[start_brace:end_brace + 1]
+
+                data = json.loads(s)
+                if not isinstance(data, dict):
+                    return None
+
+                result = {}
+                target_keys = keys if keys else list(data.keys())
+
+                for key in target_keys:
+                    if key in data:
+                        val = data[key]
+                        if isinstance(val, (int, float)):
+                            result[key] = float(val)
+                        elif isinstance(val, str):
+                            try:
+                                result[key] = float(val)
+                            except ValueError:
+                                pass
+
+                return result if result else None
+            except (json.JSONDecodeError, AttributeError, ValueError):
+                return None
+        return parser
+
+    @staticmethod
+    def multi_score_pattern(pattern: str = r'(\w+):\s*([\d.]+)') -> Callable[[str], Optional[Dict[str, float]]]:
+        """
+        Create a parser that extracts key-value pairs using regex pattern.
+        Default pattern matches "key: value" format.
+        """
+        def parser(s: str) -> Optional[Dict[str, float]]:
+            if not s:
+                return None
+            matches = re.findall(pattern, s)
+            if not matches:
+                return None
+            result = {}
+            for key, val in matches:
+                try:
+                    result[key.lower()] = float(val)
+                except ValueError:
+                    pass
+            return result if result else None
+        return parser
+
+
+def _infer_return_type(value: Any) -> Optional[ReturnType]:
+    """Infer the return type from a parsed value."""
+    if value is None:
+        return None
+    if isinstance(value, bool):
+        return ReturnType.BOOLEAN
+    if isinstance(value, (int, float)):
+        return ReturnType.SCALAR
+    if isinstance(value, dict) and all(isinstance(v, (int, float)) for v in value.values()):
+        return ReturnType.MAP
+    return None
+
+
+def _validate_consistent_types(votes: List[Dict[str, Any]]) -> ReturnType:
+    """Validate that all votes have consistent return types."""
+    return_types = set()
+    for vote in votes:
+        result = vote.get("result")
+        if result is not None:
+            rt = _infer_return_type(result)
+            if rt is not None:
+                return_types.add(rt)
+
+    if len(return_types) == 0:
+        raise ValueError("All parsers returned None - cannot determine return type")
+    if len(return_types) > 1:
+        raise ValueError(
+            f"Mixed return types detected: {[rt.value for rt in return_types]}. "
+            "All judges must return the same type (boolean, scalar, or map)."
+        )
+    return return_types.pop()
+
+
+def _validate_consistent_map_keys(votes: List[Dict[str, Any]]) -> set:
+    """Validate that all map votes have the same keys."""
+    all_keys = None
+    for vote in votes:
+        result = vote.get("result")
+        if isinstance(result, dict):
+            keys = set(result.keys())
+            if all_keys is None:
+                all_keys = keys
+            elif keys != all_keys:
+                raise ValueError(f"Inconsistent map keys across voters. Expected {all_keys}, got {keys}")
+    return all_keys or set()
+
 
 class LLMAsAJudge:
     BASE_TEMPLATE = """\
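As a quick illustration of the new map parser added above, a sketch of OutputParsers.json_map() applied to a typical fenced judge reply (the reply text is invented for illustration):

from llmasajudge import OutputParsers

parse = OutputParsers.json_map(keys=["accuracy", "helpfulness"])
reply = 'Scores:\n```json\n{"accuracy": 8, "helpfulness": "7", "notes": "solid"}\n```'
# Fences are stripped, the JSON object is located, and string values are coerced to float.
print(parse(reply))   # {'accuracy': 8.0, 'helpfulness': 7.0}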
@@ -439,123 +1194,13 @@ Return a single numeric score from 0-10 indicating how well the model output matches the ground truth.
 - 1-3 = poor match
 - 0 = completely wrong
 Output only the number. No explanation. No extra text.""",
+        'multi_score': """\
+Evaluate the model output against the ground truth on multiple dimensions.
+Return your scores in JSON format with numeric values (0-10 scale).
+Example: {"accuracy": 8, "helpfulness": 7, "relevance": 9}
+Output only valid JSON. No explanation. No extra text.""",
 }
 
[... old lines 444-558 removed: blank lines plus a '#'-commented draft of the __init__ signature, docstring, and body that 0.1.13 had carried along ...]
 
     def __init__(
         self,
         models: Optional[List[str]] = None,
@@ -565,7 +1210,7 @@ Output only the number. No explanation. No extra text.""",
         custom_template: Optional[str] = None,
         use_fully_custom_prompt: bool = False,
         notes: Optional[str] = None,
-        output_parser: Optional[str] = 'right/wrong',
+        output_parser: Optional[Union[str, Callable]] = 'right/wrong',
         fallback_comparison: bool = True,
         default_temperature: float = 0.0,
         verbose: bool = False,
@@ -573,10 +1218,9 @@ Output only the number. No explanation. No extra text.""",
         backoff_base: float = 0.5,
         backoff_max: float = 4.0,
         custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
-        mode: str = "majority",
-        use_cache: bool = False,
+        mode: Optional[str] = None,
         litellm_cache_dir: Optional[str] = None,
-        cache_size_gb: Optional[float] = None,
+        return_type: Optional[str] = None,
     ):
         self.models = models or []
         self.custom_generation_fns = custom_generation_fns or []
@@ -584,15 +1228,11 @@ Output only the number. No explanation. No extra text.""",
         if not self.models and not self.custom_generation_fns:
             raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
 
-        if mode not in ("majority", "single", "all"):
-            raise ValueError("mode must be 'majority', 'single', or 'all'")
-
         self.config = config or {}
         self.base_headers = dict(base_headers or {})
         self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
         self.notes = notes or ""
         self.use_fully_custom_prompt = use_fully_custom_prompt
-        self.mode = mode
         self.fallback_comparison = fallback_comparison
         self.default_temperature = float(default_temperature)
         self.verbose = verbose
@@ -600,6 +1240,7 @@ Output only the number. No explanation. No extra text.""",
         self.backoff_base = float(backoff_base)
         self.backoff_max = float(backoff_max)
 
+        # Resolve output parser
         parser_name = None
         if isinstance(output_parser, str):
             parser_map = {
@@ -607,6 +1248,7 @@ Output only the number. No explanation. No extra text.""",
             'pass/fail': OutputParsers.pass_fail,
             'yes/no': OutputParsers.yes_no,
             'numeric': OutputParsers.numeric_score,
+            'multi_score': OutputParsers.json_map(),
         }
         if output_parser not in parser_map:
             raise ValueError(f"Unknown parser '{output_parser}'")
@@ -615,6 +1257,42 @@ Output only the number. No explanation. No extra text.""",
         else:
             self.output_parser = output_parser
 
+        # Determine expected return type
+        self._explicit_return_type: Optional[ReturnType] = None
+        if return_type is not None:
+            self._explicit_return_type = ReturnType(return_type)
+        elif parser_name:
+            if parser_name in ('right/wrong', 'pass/fail', 'yes/no'):
+                self._explicit_return_type = ReturnType.BOOLEAN
+            elif parser_name == 'numeric':
+                self._explicit_return_type = ReturnType.SCALAR
+            elif parser_name == 'multi_score':
+                self._explicit_return_type = ReturnType.MAP
+
+        # Resolve aggregation mode
+        if mode is not None:
+            if mode not in MODE_STR_MAP:
+                raise ValueError(f"Unknown mode '{mode}'. Available: {list(MODE_STR_MAP.keys())}")
+            self._mode = MODE_STR_MAP[mode]
+        else:
+            if self._explicit_return_type:
+                self._mode = DEFAULT_MODES[self._explicit_return_type]
+            else:
+                self._mode = AggregationMode.MAJORITY
+
+        # Validate mode against return type if known
+        if self._explicit_return_type:
+            valid_modes = VALID_MODES[self._explicit_return_type]
+            if self._mode not in valid_modes:
+                raise ValueError(
+                    f"Mode '{self._mode.value}' not valid for return type '{self._explicit_return_type.value}'. "
+                    f"Valid: {[m.value for m in valid_modes]}"
+                )
+
+        # For backward compat, expose mode as string
+        self.mode = self._mode.value
+
+        # Set template
         if use_fully_custom_prompt:
             self.template = None
         elif custom_template:
@@ -636,50 +1314,23 @@ Output only the number. No explanation. No extra text.""",
                 ground_truth="{ground_truth}",
             )
 
-        #
-
-        self.cache_enabled = use_cache or (litellm_cache_dir is not None)
+        # Optional local cache setup
+        self.cache_enabled = litellm_cache_dir is not None
         if self.cache_enabled:
-
-            if litellm.cache is None:
-                # Set default cache directory if not specified
-                if litellm_cache_dir is None:
-                    litellm_cache_dir = ".litellm_cache"
-
-                # Convert GB to bytes if specified, otherwise unlimited
-                size_limit_bytes = None if cache_size_gb is None else int(cache_size_gb * 1024 * 1024 * 1024)
-                cache_backend = UnlimitedDiskCache(litellm_cache_dir, size_limit=size_limit_bytes)
-                litellm.cache = Cache(disk_cache_dir=litellm_cache_dir)
-                litellm.cache.cache = cache_backend
-
-
-
-
-
-
-
-
+            litellm.cache = Cache(type="disk", disk_cache_dir=litellm_cache_dir)
 
     def _build_prompt(self, input: Any, model_output: Any, ground_truth: Any) -> str:
         notes_section = f"notes:\n{self.notes}\n" if self.notes else ""
         input_text = str(input) if input not in (None, "") else "[omitted input for brevity]"
-        return self.template.format(
-            notes_section=notes_section,
-            input_block=input_text,
-            model_output=str(model_output),
-            ground_truth=str(ground_truth),
-        )
 
-
-
-
-
-
-
-
-
-            return False
-        return None
+        # Use string replacement instead of .format() to avoid issues with { } in template
+        # (e.g., JSON examples in multi_score instructions)
+        prompt = self.template
+        prompt = prompt.replace("{notes_section}", notes_section)
+        prompt = prompt.replace("{input_block}", input_text)
+        prompt = prompt.replace("{model_output}", str(model_output))
+        prompt = prompt.replace("{ground_truth}", str(ground_truth))
+        return prompt
 
     def _resolve_per_model(self, model: str) -> Tuple[Optional[str], Dict[str, str], float]:
         provider = model.split("/", 1)[0] if "/" in model else model
@@ -688,7 +1339,6 @@ Output only the number. No explanation. No extra text.""",
         headers: Dict[str, str] = dict(self.base_headers)
         temperature: float = self.default_temperature
 
-        # provider-level defaults
         pc = self.config.get(provider, {})
         if pc.get("api_base") is not None:
             api_base = pc["api_base"]
@@ -696,7 +1346,6 @@ Output only the number. No explanation. No extra text.""",
         if "temperature" in pc:
             temperature = float(pc["temperature"])
 
-        # model-level overrides
         mc = self.config.get(model, {})
         if mc.get("api_base") is not None:
             api_base = mc["api_base"]
@@ -704,7 +1353,6 @@ Output only the number. No explanation. No extra text.""",
         if "temperature" in mc:
             temperature = float(mc["temperature"])
 
-        # wandb defaults
         if provider == "wandb":
             if api_base is None:
                 api_base = "https://api.inference.wandb.ai/v1"
@@ -726,15 +1374,6 @@ Output only the number. No explanation. No extra text.""",
726 1374 |        last_err = None
727 1375 |        for i in range(attempts):
728 1376 |            try:
729 | -                # resp = completion(
730 | -                # model=model,
731 | -                # api_base=api_base, # None uses provider default
732 | -                # messages=[{"role": "user", "content": prompt}],
733 | -                # temperature=temperature,
734 | -                # max_tokens=max_tokens,
735 | -                # extra_headers=headers,
736 | -                # )
737 | -
738 1377 |                resp = completion(
739 1378 |                    model=model,
740 1379 |                    api_base=api_base,
@@ -743,20 +1382,20 @@ Output only the number. No explanation. No extra text.""",
743 1382 |                    max_tokens=max_tokens,
744 1383 |                    extra_headers=headers,
745 1384 |                    caching=self.cache_enabled
746 | -                )
1385 | +                )
747 1386 |                return (resp.choices[0].message.content or "").strip()
748 1387 |            except Exception as e:
749 1388 |                last_err = e
750 1389 |                if i == attempts - 1:
751 1390 |                    break
752 1391 |                sleep_s = min(self.backoff_max, self.backoff_base * (2 ** i))
753 | -                jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))
1392 | +                jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))
754 1393 |                if self.verbose:
755 1394 |                    print(f"[retry {i+1}/{attempts-1}] {model} error: {e} — sleeping {max(0.0, sleep_s + jitter):.2f}s", flush=True)
756 1395 |                time.sleep(max(0.0, sleep_s + jitter))
757 | -        raise last_err
1396 | +        raise last_err
758 1397 |
759 | -    def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any):
1398 | +    def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any) -> Any:
760 1399 |        api_base, headers, temperature = self._resolve_per_model(model)
761 1400 |
762 1401 |        content = self._attempt_completion(
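The retry loop backs off exponentially and adds up to plus or minus 10% jitter before sleeping. The same schedule in isolation (attempts, backoff_base and backoff_max are illustrative values):

    import random, time

    attempts, backoff_base, backoff_max = 4, 1.0, 30.0
    for i in range(attempts - 1):
        sleep_s = min(backoff_max, backoff_base * (2 ** i))      # 1s, 2s, 4s, ... capped at 30s
        jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))   # uniform in [-10%, +10%]
        time.sleep(max(0.0, sleep_s + jitter))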
@@ -768,15 +1407,80 @@ Output only the number. No explanation. No extra text.""",
768 1407 |            max_tokens=max_tokens,
769 1408 |        )
770 1409 |
771 | -        # Use the instance parser
772 1410 |        parsed = self.output_parser(content)
773 1411 |
774 | -        # If parser returns None and fallback is enabled, do string comparison
775 1412 |        if parsed is None and self.fallback_comparison:
776 1413 |            return str(model_output).strip() == str(ground_truth).strip()
777 1414 |
778 1415 |        return parsed
779 1416 |
1417 | +    def _aggregate_boolean(self, votes: List[Dict[str, Any]]) -> bool:
1418 | +        results = [v["result"] for v in votes if v["result"] is not None]
1419 | +        if not results:
1420 | +            raise ValueError("No valid votes to aggregate")
1421 | +
1422 | +        if self._mode == AggregationMode.SINGLE:
1423 | +            return bool(results[0])
1424 | +        elif self._mode == AggregationMode.MAJORITY:
1425 | +            return sum(1 for r in results if r) >= len(results) / 2
1426 | +        elif self._mode == AggregationMode.ALL:
1427 | +            return all(results)
1428 | +        else:
1429 | +            raise ValueError(f"Invalid mode for boolean: {self._mode}")
1430 | +
1431 | +    def _aggregate_scalar(self, votes: List[Dict[str, Any]]) -> float:
1432 | +        results = [float(v["result"]) for v in votes if v["result"] is not None]
1433 | +        if not results:
1434 | +            raise ValueError("No valid votes to aggregate")
1435 | +
1436 | +        if self._mode == AggregationMode.SINGLE:
1437 | +            return results[0]
1438 | +        elif self._mode == AggregationMode.AVERAGE:
1439 | +            return sum(results) / len(results)
1440 | +        elif self._mode == AggregationMode.MIN:
1441 | +            return min(results)
1442 | +        elif self._mode == AggregationMode.MAX:
1443 | +            return max(results)
1444 | +        elif self._mode == AggregationMode.MEDIAN:
1445 | +            s = sorted(results)
1446 | +            n = len(s)
1447 | +            mid = n // 2
1448 | +            return (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
1449 | +        else:
1450 | +            raise ValueError(f"Invalid mode for scalar: {self._mode}")
1451 | +
1452 | +    def _aggregate_map(self, votes: List[Dict[str, Any]]) -> Dict[str, float]:
1453 | +        valid = [v["result"] for v in votes if v["result"] is not None and isinstance(v["result"], dict)]
1454 | +        if not valid:
1455 | +            raise ValueError("No valid map votes to aggregate")
1456 | +
1457 | +        keys = set()
1458 | +        for v in valid:
1459 | +            keys.update(v.keys())
1460 | +
1461 | +        if self._mode == AggregationMode.SINGLE:
1462 | +            return valid[0]
1463 | +
1464 | +        result = {}
1465 | +        for key in keys:
1466 | +            values = [v[key] for v in valid if key in v]
1467 | +            if not values:
1468 | +                continue
1469 | +
1470 | +            if self._mode == AggregationMode.AVERAGE:
1471 | +                result[key] = sum(values) / len(values)
1472 | +            elif self._mode == AggregationMode.MIN:
1473 | +                result[key] = min(values)
1474 | +            elif self._mode == AggregationMode.MAX:
1475 | +                result[key] = max(values)
1476 | +            elif self._mode == AggregationMode.MEDIAN:
1477 | +                s = sorted(values)
1478 | +                n = len(s)
1479 | +                mid = n // 2
1480 | +                result[key] = (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
1481 | +
1482 | +        return result
1483 | +
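The three new _aggregate_* helpers reduce per-judge votes to one value per mode. A quick trace of the scalar MEDIAN branch using the same arithmetic (vote values are invented):

    votes = [{"result": 0.2}, {"result": 0.9}, {"result": 0.6}, {"result": None}]
    results = [float(v["result"]) for v in votes if v["result"] is not None]  # [0.2, 0.9, 0.6]
    s = sorted(results)                                                       # [0.2, 0.6, 0.9]
    n, mid = len(s), len(s) // 2
    median = (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]              # 0.6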
780 1484 |    def judge(
781 1485 |        self,
782 1486 |        input: Any = None,
@@ -784,23 +1488,25 @@ Output only the number. No explanation. No extra text.""",
784 1488 |        ground_truth: Any = None,
785 1489 |        prompt: Optional[str] = None,
786 1490 |        max_tokens: int = 10000,
787 | -    ):
788 | -
1491 | +    ) -> Dict[str, Any]:
1492 | +        """
1493 | +        Run the judge evaluation.
1494 | +
1495 | +        Returns dict with:
1496 | +        - 'correct': bool or None (bool for boolean type, None for scalar/map)
1497 | +        - 'scores': float or Dict[str, float] or None (for scalar/map, None for boolean)
1498 | +        - 'mode': aggregation mode string
1499 | +        - 'votes': list of individual judge votes
1500 | +        """
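A hedged usage sketch of the boolean path matching the documented return shape (constructor arguments and model name are placeholders, not taken from the package):

    judge = LLMAsAJudge(models=["openai/gpt-4o-mini"])
    result = judge.judge(input="What is 2 + 2?", model_output="4", ground_truth="4")
    # Expected shape for a boolean-type judge:
    # {"correct": True, "scores": None, "mode": "...", "votes": [...]}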
789 1501 |        if self.use_fully_custom_prompt:
790 1502 |            if prompt is None:
791 | -                raise ValueError(
792 | -                    "When use_fully_custom_prompt=True, you must pass prompt to judge()."
793 | -                )
1503 | +                raise ValueError("When use_fully_custom_prompt=True, you must pass prompt to judge().")
794 1504 |            if input is not None or model_output is not None or ground_truth is not None:
795 1505 |                raise ValueError(
796 | -                    "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth
797 | -                    "Only pass the complete prompt."
1506 | +                    "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth."
798 1507 |                )
799 1508 |        elif prompt is not None:
800 | -            raise ValueError(
801 | -                "prompt parameter can only be used when use_fully_custom_prompt=True. "
802 | -                "Otherwise, use input/model_output/ground_truth."
803 | -            )
1509 | +            raise ValueError("prompt parameter can only be used when use_fully_custom_prompt=True.")
804 1510 |        else:
805 1511 |            prompt = self._build_prompt(input, model_output, ground_truth)
806 1512 |
@@ -808,18 +1514,27 @@ Output only the number. No explanation. No extra text.""",
808 1514 |
809 1515 |        # Vote with litellm models
810 1516 |        for m in self.models:
811 | -
812 | -
813 | -
814 | -
1517 | +            try:
1518 | +                res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
1519 | +                if self.verbose:
1520 | +                    print(f"Model {m} voted: {res}", flush=True)
1521 | +                votes.append({"model": m, "result": res})
1522 | +            except Exception as e:
1523 | +                if self.verbose:
1524 | +                    print(f"Model {m} failed: {e}", flush=True)
1525 | +                if self.fallback_comparison:
1526 | +                    res = str(model_output).strip() == str(ground_truth).strip()
1527 | +                    votes.append({"model": m, "result": res, "error": str(e)})
1528 | +                else:
1529 | +                    raise
815 1530 |
816 1531 |        # Vote with custom generation functions
817 1532 |        for idx, custom_fn in enumerate(self.custom_generation_fns):
1533 | +            fn_name = f"custom_fn_{idx}"
818 1534 |            try:
819 1535 |                content = custom_fn(prompt)
820 1536 |                parsed = self.output_parser(content)
821 1537 |
822 | -                # If parser returns None and fallback is enabled, do string comparison
823 1538 |                if parsed is None and self.fallback_comparison:
824 1539 |                    res = str(model_output).strip() == str(ground_truth).strip()
825 1540 |                else:
@@ -827,27 +1542,70 @@ Output only the number. No explanation. No extra text.""",
827 1542 |
828 1543 |                if self.verbose:
829 1544 |                    print(f"Custom function {idx} voted: {res}", flush=True)
830 | -                votes.append({"model":
1545 | +                votes.append({"model": fn_name, "result": res})
831 1546 |            except Exception as e:
832 1547 |                if self.verbose:
833 1548 |                    print(f"Custom function {idx} failed: {e}", flush=True)
834 | -                # If custom function fails and fallback is enabled, do string comparison
835 1549 |                if self.fallback_comparison:
836 1550 |                    res = str(model_output).strip() == str(ground_truth).strip()
837 | -                    votes.append({"model":
1551 | +                    votes.append({"model": fn_name, "result": res, "error": str(e)})
838 1552 |                else:
839 1553 |                    raise
840 1554 |
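Custom generation functions are plain callables that receive the rendered prompt and return raw text for the output parser; their votes are recorded under "custom_fn_<idx>". A hedged sketch of the wiring (the constructor parameter name is assumed from self.custom_generation_fns above):

    def my_generate(prompt: str) -> str:
        # call any model or service here; a fixed verdict keeps the sketch self-contained
        return "true"

    judge = LLMAsAJudge(custom_generation_fns=[my_generate])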
841 | -
842 | -
843 | -
844 | -
845 | -            false_votes = len(votes) - true_votes
846 | -            final = True if true_votes >= false_votes else False
847 | -        elif self.mode == "all":
848 | -            final = all(v["correct"] for v in votes)
1555 | +        # Determine return type
1556 | +        return_type = self._explicit_return_type
1557 | +        if return_type is None:
1558 | +            return_type = _validate_consistent_types(votes)
849 1559 |        else:
850 | -
851 | -
852 | -
853 | -
1560 | +            actual = _validate_consistent_types(votes)
1561 | +            if actual != return_type:
1562 | +                raise ValueError(f"Expected return type '{return_type.value}' but got '{actual.value}'")
1563 | +
1564 | +        # Auto-correct mode if needed
1565 | +        if self._mode not in VALID_MODES[return_type]:
1566 | +            self._mode = DEFAULT_MODES[return_type]
1567 | +            self.mode = self._mode.value
1568 | +
1569 | +        # Validate map keys
1570 | +        if return_type == ReturnType.MAP:
1571 | +            _validate_consistent_map_keys(votes)
1572 | +
1573 | +        # Aggregate
1574 | +        if return_type == ReturnType.BOOLEAN:
1575 | +            final = self._aggregate_boolean(votes)
1576 | +        elif return_type == ReturnType.SCALAR:
1577 | +            final = self._aggregate_scalar(votes)
1578 | +        elif return_type == ReturnType.MAP:
1579 | +            final = self._aggregate_map(votes)
1580 | +        else:
1581 | +            raise ValueError(f"Unknown return type: {return_type}")
1582 | +
1583 | +        # Build backward-compatible response
1584 | +        # Boolean: correct=bool, scores=None
1585 | +        # Scalar: correct=score, scores=score (both fields for convenience)
1586 | +        # Map: correct=None, scores=map
1587 | +        if return_type == ReturnType.BOOLEAN:
1588 | +            # Also put "correct" in each vote for backward compat
1589 | +            for v in votes:
1590 | +                v["correct"] = v["result"]
1591 | +            return {
1592 | +                "correct": final,
1593 | +                "scores": None,
1594 | +                "mode": self.mode,
1595 | +                "votes": votes,
1596 | +            }
1597 | +        elif return_type == ReturnType.SCALAR:
1598 | +            # For scalar, put score in both correct and scores for convenience
1599 | +            return {
1600 | +                "correct": final,
1601 | +                "scores": final,
1602 | +                "mode": self.mode,
1603 | +                "votes": votes,
1604 | +            }
1605 | +        else:  # MAP
1606 | +            return {
1607 | +                "correct": None,
1608 | +                "scores": final,
1609 | +                "mode": self.mode,
1610 | +                "votes": votes,
1611 | +            }