llmasajudge 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llmasajudge/__init__.py CHANGED
@@ -201,141 +201,726 @@
201
201
 
202
202
 
203
203
 
204
- import os
205
- import time
206
- import random
207
- import re
208
- from typing import Any, Callable, Dict, List, Optional, Tuple
209
- import litellm
210
- from litellm import completion
211
- from litellm.caching.caching import Cache
204
+ # import os
205
+ # import time
206
+ # import random
207
+ # import re
208
+ # from typing import Any, Callable, Dict, List, Optional, Tuple
209
+ # import litellm
210
+ # from litellm import completion
211
+ # from litellm.caching.caching import Cache
212
+
213
+
214
+ # __all__ = ["LLMAsAJudge", "OutputParsers"]
215
+
216
+
217
+ # class UnlimitedDiskCache:
218
+ # """
219
+ # Drop-in replacement backend with 'unlimited' size for LiteLLM cache.
220
+
221
+ # This wraps diskcache.Cache with a very large size limit (2^62 bytes ~ 4.6 exabytes)
222
+ # to effectively disable automatic cache eviction, allowing the cache to grow
223
+ # without size constraints.
224
+ # """
225
+
226
+ # def __init__(self, directory, size_limit=None):
227
+ # """
228
+ # Initialize unlimited disk cache.
229
+
230
+ # Args:
231
+ # directory: Path to cache directory
232
+ # size_limit: Optional size limit in bytes. If None, uses 2^62 bytes (~4.6 exabytes)
233
+ # """
234
+ # import diskcache as dc
235
+
236
+ # # Set to very large cap so culling never triggers (effectively unlimited)
237
+ # cap = size_limit if size_limit is not None else (1 << 62)
238
+ # self._dc = dc.Cache(directory, size_limit=cap)
239
+
240
+ # # Sync API used by LiteLLM
241
+ # def get_cache(self, key, **kwargs):
242
+ # """Get value from cache by key."""
243
+ # return self._dc.get(key)
244
+
245
+ # def set_cache(self, key, value, ttl=None, **kwargs):
246
+ # """Set value in cache with optional TTL."""
247
+ # expire = None if ttl is None else float(ttl)
248
+ # self._dc.set(key, value, expire=expire)
249
+
250
+ # # Async API used by LiteLLM
251
+ # async def async_get_cache(self, key, **kwargs):
252
+ # """Async get value from cache by key."""
253
+ # return self.get_cache(key, **kwargs)
254
+
255
+ # async def async_set_cache(self, key, value, ttl=None, **kwargs):
256
+ # """Async set value in cache with optional TTL."""
257
+ # return self.set_cache(key, value, ttl=ttl, **kwargs)
258
+
259
+ # async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
260
+ # """
261
+ # Async batch set multiple cache entries.
262
+
263
+ # Args:
264
+ # cache_list: List of (key, value) tuples
265
+ # ttl: Optional time-to-live in seconds
266
+ # """
267
+ # for k, v in cache_list:
268
+ # self.set_cache(k, v, ttl=ttl)
269
+
270
+ # async def batch_cache_write(self, key, value, ttl=None, **kwargs):
271
+ # """Async batch write (single entry)."""
272
+ # self.set_cache(key, value, ttl=ttl)
273
+
274
+ # async def ping(self):
275
+ # """Async ping check."""
276
+ # return True
277
+
278
+ # async def delete_cache_keys(self, keys):
279
+ # """
280
+ # Async delete multiple cache keys.
281
+
282
+ # Args:
283
+ # keys: List of keys to delete
284
+ # """
285
+ # for k in keys:
286
+ # try:
287
+ # del self._dc[k]
288
+ # except KeyError:
289
+ # pass
290
+ # return True
291
+
292
+ # async def disconnect(self):
293
+ # """Async disconnect and close cache."""
294
+ # self._dc.close()
295
+
296
+ # def get_stats(self):
297
+ # """
298
+ # Get cache statistics.
299
+
300
+ # Returns:
301
+ # dict with size_limit, current_size, item_count, and percent_full
302
+ # """
303
+ # size_limit = self._dc.size_limit
304
+ # volume = self._dc.volume() # Current size in bytes
305
+ # count = len(self._dc) # Number of items
306
+
307
+ # return {
308
+ # "size_limit": size_limit,
309
+ # "current_size": volume,
310
+ # "item_count": count,
311
+ # "percent_full": (volume / size_limit) * 100 if size_limit > 0 else 0.0,
312
+ # }
313
+
314
+ # def print_stats(self):
315
+ # """Print human-readable cache statistics."""
316
+ # stats = self.get_stats()
317
+
318
+ # def human_size(bytes_val):
319
+ # """Convert bytes to human readable format."""
320
+ # for unit in ["B", "KB", "MB", "GB", "TB", "PB", "EB"]:
321
+ # if bytes_val < 1024.0:
322
+ # return f"{bytes_val:.2f} {unit}"
323
+ # bytes_val /= 1024.0
324
+ # return f"{bytes_val:.2f} EB"
325
+
326
+ # print("=" * 60)
327
+ # print("CACHE STATISTICS")
328
+ # print("=" * 60)
329
+ # print(f" Size limit: {human_size(stats['size_limit'])} ({stats['size_limit']:,} bytes)")
330
+ # print(f" Current size: {human_size(stats['current_size'])} ({stats['current_size']:,} bytes)")
331
+ # print(f" Items cached: {stats['item_count']}")
332
+ # print(f" % full: {stats['percent_full']:.6f}%")
333
+ # print("=" * 60)
334
+
335
+
336
+ # class OutputParsers:
337
+ # """Stock output parsers for common judge output formats."""
212
338
 
339
+ # @staticmethod
340
+ # def right_wrong(s: str) -> Optional[bool]:
341
+ # """Parse 'right' or 'wrong' from the last 6 characters."""
342
+ # if not s:
343
+ # return None
344
+ # tail = s.strip()[-6:].lower()
345
+ # if "right" in tail:
346
+ # return True
347
+ # if "wrong" in tail:
348
+ # return False
349
+ # return None
213
350
 
214
- __all__ = ["LLMAsAJudge", "OutputParsers"]
351
+ # @staticmethod
352
+ # def pass_fail(s: str) -> Optional[bool]:
353
+ # """Parse 'pass' or 'fail' from the response."""
354
+ # if not s:
355
+ # return None
356
+ # text = s.strip().lower()
357
+ # if "pass" in text:
358
+ # return True
359
+ # if "fail" in text:
360
+ # return False
361
+ # return None
215
362
 
363
+ # @staticmethod
364
+ # def yes_no(s: str) -> Optional[bool]:
365
+ # """Parse 'yes' or 'no' from the response."""
366
+ # if not s:
367
+ # return None
368
+ # text = s.strip().lower()
369
+ # if "yes" in text:
370
+ # return True
371
+ # if "no" in text:
372
+ # return False
373
+ # return None
216
374
 
217
- class UnlimitedDiskCache:
218
- """
219
- Drop-in replacement backend with 'unlimited' size for LiteLLM cache.
375
+ # @staticmethod
376
+ # def numeric_score(s: str) -> Optional[float]:
377
+ # """Extract first numeric value from the response."""
378
+ # if not s:
379
+ # return None
380
+ # match = re.search(r'[-+]?\d*\.?\d+', s.strip())
381
+ # if match:
382
+ # return float(match.group())
383
+ # return None
220
384
 
221
- This wraps diskcache.Cache with a very large size limit (2^62 bytes ~ 4.6 exabytes)
222
- to effectively disable automatic cache eviction, allowing the cache to grow
223
- without size constraints.
224
- """
385
+ # @staticmethod
386
+ # def json_extract(key: str) -> Callable[[str], Any]:
387
+ # """Create a parser that extracts a specific key from JSON output."""
388
+ # import json
389
+ # def parser(s: str) -> Any:
390
+ # if not s:
391
+ # return None
392
+ # try:
393
+ # data = json.loads(s.strip())
394
+ # return data.get(key)
395
+ # except (json.JSONDecodeError, AttributeError):
396
+ # return None
397
+ # return parser
398
+
399
+
400
+ # class LLMAsAJudge:
401
+ # BASE_TEMPLATE = """\
402
+ # You are a judge. Read input, model_output, and ground_truth.
403
+ # {instruction}
404
+ # ##################
405
+ # {notes_section}### input:
406
+ # {input_block}
407
+ # ##################
408
+ # model's output:
409
+ # {model_output}
410
+ # ##################
411
+ # ground_truth answer:
412
+ # {ground_truth}
413
+ # ##################
414
+ # """
225
415
 
226
- def __init__(self, directory, size_limit=None):
227
- """
228
- Initialize unlimited disk cache.
416
+ # PARSER_INSTRUCTIONS = {
417
+ # 'right/wrong': """\
418
+ # Return exactly one word: right or wrong.
419
+ # Rules:
420
+ # - Treat extra words or punctuation as irrelevant if the same final value is present.
421
+ # - Output must be exactly right or wrong. No JSON. No quotes. No extra text.""",
422
+ # 'yes/no': """\
423
+ # Return exactly one word: yes or no.
424
+ # Answer yes if the model output matches the ground truth, no otherwise.
425
+ # Rules:
426
+ # - Treat extra words or punctuation as irrelevant if the same final value is present.
427
+ # - Output must be exactly yes or no. No JSON. No quotes. No extra text.""",
428
+ # 'pass/fail': """\
429
+ # Return exactly one word: pass or fail.
430
+ # Answer pass if the model output matches the ground truth, fail otherwise.
431
+ # Rules:
432
+ # - Treat extra words or punctuation as irrelevant if the same final value is present.
433
+ # - Output must be exactly pass or fail. No JSON. No quotes. No extra text.""",
434
+ # 'numeric': """\
435
+ # Return a single numeric score from 0-10 indicating how well the model output matches the ground truth.
436
+ # - 10 = perfect match
437
+ # - 7-9 = close match with minor differences
438
+ # - 4-6 = partial match
439
+ # - 1-3 = poor match
440
+ # - 0 = completely wrong
441
+ # Output only the number. No explanation. No extra text.""",
442
+ # }
443
+
444
+
445
+
446
+
447
+ # # def __init__(
448
+ # # self,
449
+ # # models: Optional[List[str]] = None,
450
+ # # config: Optional[Dict[str, Dict[str, Any]]] = None, # one dict for providers and models
451
+ # # base_headers: Optional[Dict[str, str]] = None,
452
+ # # wandb_project: Optional[str] = None,
453
+ # # custom_template: Optional[str] = None,
454
+ # # use_fully_custom_prompt: bool = False,
455
+ # # notes: Optional[str] = None,
456
+ # # output_parser: Optional[str] = 'right/wrong',
457
+ # # fallback_comparison: bool = True,
458
+ # # default_temperature: float = 0.0,
459
+ # # verbose: bool = False,
460
+ # # num_retries: int = 2, # per-call retries before giving up on that model
461
+ # # backoff_base: float = 0.5, # seconds
462
+ # # backoff_max: float = 4.0, # seconds
463
+ # # custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
464
+ # # mode: str = "majority", # "single", "majority", "all"
465
+ # # ):
466
+ # # """
467
+ # # config keys can be a provider name ("wandb", "openai", "anthropic")
468
+ # # or a full model name ("openai/gpt-4o-mini", "wandb/deepseek-ai/DeepSeek-V3.1").
469
+
470
+ # # Values can include:
471
+ # # api_base: Optional[str]
472
+ # # headers: Dict[str, str]
473
+ # # temperature: float
474
+
475
+ # # Precedence:
476
+ # # base_headers < provider config < model config
477
+
478
+ # # Args:
479
+ # # models: List of litellm model strings (e.g., ["openai/gpt-4", "anthropic/claude-3"])
480
+ # # custom_template: Template with placeholders for input/output/ground_truth
481
+ # # use_fully_custom_prompt: If True, pass complete prompt to judge(prompt=...).
482
+ # # When True, input/output/ground_truth must NOT be passed to judge()
483
+ # # output_parser: Parser name ('right/wrong', 'yes/no', 'pass/fail', 'numeric')
484
+ # # or custom function with signature (str) -> Any
485
+ # # fallback_comparison: If True and parser returns None, falls back to string comparison
486
+ # # custom_generation_fns: List of custom inference functions with signature fn(prompt: str) -> str
487
+ # # These will be used in addition to litellm models for voting.
488
+ # # mode: Voting mode - "majority" (default), "single" (first judge only), or "all" (unanimous)
489
+ # # """
490
+ # # self.models = models or []
491
+ # # self.custom_generation_fns = custom_generation_fns or []
492
+
493
+ # # # Validate that at least one judge is provided
494
+ # # if not self.models and not self.custom_generation_fns:
495
+ # # raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
496
+
497
+ # # # Validate mode
498
+ # # if mode not in ("majority", "single", "all"):
499
+ # # raise ValueError("mode must be 'majority', 'single', or 'all'")
500
+
501
+ # # self.config = config or {}
502
+ # # self.base_headers = dict(base_headers or {})
503
+ # # self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
504
+ # # self.notes = notes or ""
505
+ # # self.use_fully_custom_prompt = use_fully_custom_prompt
506
+ # # self.mode = mode
507
+
508
+ # # # Resolve output parser
509
+ # # parser_name = None
510
+ # # if isinstance(output_parser, str):
511
+ # # parser_map = {
512
+ # # 'right/wrong': OutputParsers.right_wrong,
513
+ # # 'pass/fail': OutputParsers.pass_fail,
514
+ # # 'yes/no': OutputParsers.yes_no,
515
+ # # 'numeric': OutputParsers.numeric_score,
516
+ # # }
517
+ # # if output_parser not in parser_map:
518
+ # # raise ValueError(f"Unknown parser '{output_parser}'. Available: {list(parser_map.keys())}")
519
+ # # self.output_parser = parser_map[output_parser]
520
+ # # parser_name = output_parser
521
+ # # else:
522
+ # # self.output_parser = output_parser
523
+
524
+ # # # Set template based on mode
525
+ # # if use_fully_custom_prompt:
526
+ # # self.template = None # No template in fully custom mode
527
+ # # elif custom_template:
528
+ # # self.template = custom_template
529
+ # # elif parser_name and parser_name in self.PARSER_INSTRUCTIONS:
530
+ # # self.template = self.BASE_TEMPLATE.format(
531
+ # # instruction=self.PARSER_INSTRUCTIONS[parser_name],
532
+ # # notes_section="{notes_section}",
533
+ # # input_block="{input_block}",
534
+ # # model_output="{model_output}",
535
+ # # ground_truth="{ground_truth}",
536
+ # # )
537
+ # # else:
538
+ # # # Default to right/wrong for custom parsers
539
+ # # self.template = self.BASE_TEMPLATE.format(
540
+ # # instruction=self.PARSER_INSTRUCTIONS['right/wrong'],
541
+ # # notes_section="{notes_section}",
542
+ # # input_block="{input_block}",
543
+ # # model_output="{model_output}",
544
+ # # ground_truth="{ground_truth}",
545
+ # # )
546
+
547
+ # # self.fallback_comparison = fallback_comparison
548
+ # # self.default_temperature = float(default_temperature)
549
+ # # self.verbose = verbose
550
+ # # self.num_retries = int(num_retries)
551
+ # # self.backoff_base = float(backoff_base)
552
+ # # self.backoff_max = float(backoff_max)
553
+
554
+
555
+
556
+
557
+
558
+
559
+ # def __init__(
560
+ # self,
561
+ # models: Optional[List[str]] = None,
562
+ # config: Optional[Dict[str, Dict[str, Any]]] = None,
563
+ # base_headers: Optional[Dict[str, str]] = None,
564
+ # wandb_project: Optional[str] = None,
565
+ # custom_template: Optional[str] = None,
566
+ # use_fully_custom_prompt: bool = False,
567
+ # notes: Optional[str] = None,
568
+ # output_parser: Optional[str] = 'right/wrong',
569
+ # fallback_comparison: bool = True,
570
+ # default_temperature: float = 0.0,
571
+ # verbose: bool = False,
572
+ # num_retries: int = 2,
573
+ # backoff_base: float = 0.5,
574
+ # backoff_max: float = 4.0,
575
+ # custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
576
+ # mode: str = "majority",
577
+ # use_cache: bool = False,
578
+ # litellm_cache_dir: Optional[str] = None,
579
+ # cache_size_gb: Optional[float] = None,
580
+ # ):
581
+ # self.models = models or []
582
+ # self.custom_generation_fns = custom_generation_fns or []
583
+
584
+ # if not self.models and not self.custom_generation_fns:
585
+ # raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
586
+
587
+ # if mode not in ("majority", "single", "all"):
588
+ # raise ValueError("mode must be 'majority', 'single', or 'all'")
589
+
590
+ # self.config = config or {}
591
+ # self.base_headers = dict(base_headers or {})
592
+ # self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
593
+ # self.notes = notes or ""
594
+ # self.use_fully_custom_prompt = use_fully_custom_prompt
595
+ # self.mode = mode
596
+ # self.fallback_comparison = fallback_comparison
597
+ # self.default_temperature = float(default_temperature)
598
+ # self.verbose = verbose
599
+ # self.num_retries = int(num_retries)
600
+ # self.backoff_base = float(backoff_base)
601
+ # self.backoff_max = float(backoff_max)
602
+
603
+ # parser_name = None
604
+ # if isinstance(output_parser, str):
605
+ # parser_map = {
606
+ # 'right/wrong': OutputParsers.right_wrong,
607
+ # 'pass/fail': OutputParsers.pass_fail,
608
+ # 'yes/no': OutputParsers.yes_no,
609
+ # 'numeric': OutputParsers.numeric_score,
610
+ # }
611
+ # if output_parser not in parser_map:
612
+ # raise ValueError(f"Unknown parser '{output_parser}'")
613
+ # self.output_parser = parser_map[output_parser]
614
+ # parser_name = output_parser
615
+ # else:
616
+ # self.output_parser = output_parser
617
+
618
+ # if use_fully_custom_prompt:
619
+ # self.template = None
620
+ # elif custom_template:
621
+ # self.template = custom_template
622
+ # elif parser_name and parser_name in self.PARSER_INSTRUCTIONS:
623
+ # self.template = self.BASE_TEMPLATE.format(
624
+ # instruction=self.PARSER_INSTRUCTIONS[parser_name],
625
+ # notes_section="{notes_section}",
626
+ # input_block="{input_block}",
627
+ # model_output="{model_output}",
628
+ # ground_truth="{ground_truth}",
629
+ # )
630
+ # else:
631
+ # self.template = self.BASE_TEMPLATE.format(
632
+ # instruction=self.PARSER_INSTRUCTIONS['right/wrong'],
633
+ # notes_section="{notes_section}",
634
+ # input_block="{input_block}",
635
+ # model_output="{model_output}",
636
+ # ground_truth="{ground_truth}",
637
+ # )
229
638
 
230
- Args:
231
- directory: Path to cache directory
232
- size_limit: Optional size limit in bytes. If None, uses 2^62 bytes (~4.6 exabytes)
233
- """
234
- import diskcache as dc
639
+ # # optional local cache setup
640
+ # # Enable cache if use_cache=True OR if litellm_cache_dir is explicitly provided (backward compatible)
641
+ # self.cache_enabled = use_cache or (litellm_cache_dir is not None)
642
+ # if self.cache_enabled:
643
+ # # Only set up cache if it hasn't been set up already
644
+ # if litellm.cache is None:
645
+ # # Set default cache directory if not specified
646
+ # if litellm_cache_dir is None:
647
+ # litellm_cache_dir = ".litellm_cache"
235
648
 
236
- # Set to very large cap so culling never triggers (effectively unlimited)
237
- cap = size_limit if size_limit is not None else (1 << 62)
238
- self._dc = dc.Cache(directory, size_limit=cap)
649
+ # # Convert GB to bytes if specified, otherwise unlimited
650
+ # size_limit_bytes = None if cache_size_gb is None else int(cache_size_gb * 1024 * 1024 * 1024)
651
+ # cache_backend = UnlimitedDiskCache(litellm_cache_dir, size_limit=size_limit_bytes)
652
+ # litellm.cache = Cache(disk_cache_dir=litellm_cache_dir)
653
+ # litellm.cache.cache = cache_backend
239
654
 
240
- # Sync API used by LiteLLM
241
- def get_cache(self, key, **kwargs):
242
- """Get value from cache by key."""
243
- return self._dc.get(key)
244
655
 
245
- def set_cache(self, key, value, ttl=None, **kwargs):
246
- """Set value in cache with optional TTL."""
247
- expire = None if ttl is None else float(ttl)
248
- self._dc.set(key, value, expire=expire)
249
656
 
250
- # Async API used by LiteLLM
251
- async def async_get_cache(self, key, **kwargs):
252
- """Async get value from cache by key."""
253
- return self.get_cache(key, **kwargs)
254
657
 
255
- async def async_set_cache(self, key, value, ttl=None, **kwargs):
256
- """Async set value in cache with optional TTL."""
257
- return self.set_cache(key, value, ttl=ttl, **kwargs)
258
658
 
259
- async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
260
- """
261
- Async batch set multiple cache entries.
262
659
 
263
- Args:
264
- cache_list: List of (key, value) tuples
265
- ttl: Optional time-to-live in seconds
266
- """
267
- for k, v in cache_list:
268
- self.set_cache(k, v, ttl=ttl)
269
660
 
270
- async def batch_cache_write(self, key, value, ttl=None, **kwargs):
271
- """Async batch write (single entry)."""
272
- self.set_cache(key, value, ttl=ttl)
273
661
 
274
- async def ping(self):
275
- """Async ping check."""
276
- return True
277
662
 
278
- async def delete_cache_keys(self, keys):
279
- """
280
- Async delete multiple cache keys.
663
+ # def _build_prompt(self, input: Any, model_output: Any, ground_truth: Any) -> str:
664
+ # notes_section = f"notes:\n{self.notes}\n" if self.notes else ""
665
+ # input_text = str(input) if input not in (None, "") else "[omitted input for brevity]"
666
+ # return self.template.format(
667
+ # notes_section=notes_section,
668
+ # input_block=input_text,
669
+ # model_output=str(model_output),
670
+ # ground_truth=str(ground_truth),
671
+ # )
281
672
 
282
- Args:
283
- keys: List of keys to delete
284
- """
285
- for k in keys:
286
- try:
287
- del self._dc[k]
288
- except KeyError:
289
- pass
290
- return True
673
+ # @staticmethod
674
+ # def _last6_right_wrong(s: str):
675
+ # if not s:
676
+ # return None
677
+ # tail = s.strip()[-6:].lower()
678
+ # if "right" in tail:
679
+ # return True
680
+ # if "wrong" in tail:
681
+ # return False
682
+ # return None
291
683
 
292
- async def disconnect(self):
293
- """Async disconnect and close cache."""
294
- self._dc.close()
684
+ # def _resolve_per_model(self, model: str) -> Tuple[Optional[str], Dict[str, str], float]:
685
+ # provider = model.split("/", 1)[0] if "/" in model else model
686
+
687
+ # api_base: Optional[str] = None
688
+ # headers: Dict[str, str] = dict(self.base_headers)
689
+ # temperature: float = self.default_temperature
690
+
691
+ # # provider-level defaults
692
+ # pc = self.config.get(provider, {})
693
+ # if pc.get("api_base") is not None:
694
+ # api_base = pc["api_base"]
695
+ # headers.update(pc.get("headers", {}))
696
+ # if "temperature" in pc:
697
+ # temperature = float(pc["temperature"])
698
+
699
+ # # model-level overrides
700
+ # mc = self.config.get(model, {})
701
+ # if mc.get("api_base") is not None:
702
+ # api_base = mc["api_base"]
703
+ # headers.update(mc.get("headers", {}))
704
+ # if "temperature" in mc:
705
+ # temperature = float(mc["temperature"])
706
+
707
+ # # wandb defaults
708
+ # if provider == "wandb":
709
+ # if api_base is None:
710
+ # api_base = "https://api.inference.wandb.ai/v1"
711
+ # if "OpenAI-Project" not in headers:
712
+ # headers["OpenAI-Project"] = self.wandb_project or "wandb_fc/quickstart_playground"
713
+
714
+ # return api_base, headers, temperature
715
+
716
+ # def _attempt_completion(
717
+ # self,
718
+ # model: str,
719
+ # api_base: Optional[str],
720
+ # headers: Dict[str, str],
721
+ # prompt: str,
722
+ # temperature: float,
723
+ # max_tokens: int,
724
+ # ) -> str:
725
+ # attempts = self.num_retries + 1
726
+ # last_err = None
727
+ # for i in range(attempts):
728
+ # try:
729
+ # # resp = completion(
730
+ # # model=model,
731
+ # # api_base=api_base, # None uses provider default
732
+ # # messages=[{"role": "user", "content": prompt}],
733
+ # # temperature=temperature,
734
+ # # max_tokens=max_tokens,
735
+ # # extra_headers=headers,
736
+ # # )
737
+
738
+ # resp = completion(
739
+ # model=model,
740
+ # api_base=api_base,
741
+ # messages=[{"role": "user", "content": prompt}],
742
+ # temperature=temperature,
743
+ # max_tokens=max_tokens,
744
+ # extra_headers=headers,
745
+ # caching=self.cache_enabled
746
+ # )
747
+ # return (resp.choices[0].message.content or "").strip()
748
+ # except Exception as e:
749
+ # last_err = e
750
+ # if i == attempts - 1:
751
+ # break
752
+ # sleep_s = min(self.backoff_max, self.backoff_base * (2 ** i))
753
+ # jitter = sleep_s * (0.1 * (2 * random.random() - 1.0)) # ±10%
754
+ # if self.verbose:
755
+ # print(f"[retry {i+1}/{attempts-1}] {model} error: {e} — sleeping {max(0.0, sleep_s + jitter):.2f}s", flush=True)
756
+ # time.sleep(max(0.0, sleep_s + jitter))
757
+ # raise last_err # fail after retries
758
+
759
+ # def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any):
760
+ # api_base, headers, temperature = self._resolve_per_model(model)
761
+
762
+ # content = self._attempt_completion(
763
+ # model=model,
764
+ # api_base=api_base,
765
+ # headers=headers,
766
+ # prompt=prompt,
767
+ # temperature=temperature,
768
+ # max_tokens=max_tokens,
769
+ # )
295
770
 
296
- def get_stats(self):
297
- """
298
- Get cache statistics.
771
+ # # Use the instance parser
772
+ # parsed = self.output_parser(content)
773
+
774
+ # # If parser returns None and fallback is enabled, do string comparison
775
+ # if parsed is None and self.fallback_comparison:
776
+ # return str(model_output).strip() == str(ground_truth).strip()
777
+
778
+ # return parsed
779
+
780
+ # def judge(
781
+ # self,
782
+ # input: Any = None,
783
+ # model_output: Any = None,
784
+ # ground_truth: Any = None,
785
+ # prompt: Optional[str] = None,
786
+ # max_tokens: int = 10000,
787
+ # ):
788
+ # # Validation for fully_custom_prompt mode
789
+ # if self.use_fully_custom_prompt:
790
+ # if prompt is None:
791
+ # raise ValueError(
792
+ # "When use_fully_custom_prompt=True, you must pass prompt to judge()."
793
+ # )
794
+ # if input is not None or model_output is not None or ground_truth is not None:
795
+ # raise ValueError(
796
+ # "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth to judge(). "
797
+ # "Only pass the complete prompt."
798
+ # )
799
+ # elif prompt is not None:
800
+ # raise ValueError(
801
+ # "prompt parameter can only be used when use_fully_custom_prompt=True. "
802
+ # "Otherwise, use input/model_output/ground_truth."
803
+ # )
804
+ # else:
805
+ # prompt = self._build_prompt(input, model_output, ground_truth)
806
+
807
+ # votes = []
808
+
809
+ # # Vote with litellm models
810
+ # for m in self.models:
811
+ # res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
812
+ # if self.verbose:
813
+ # print(f"Model {m} voted: {res}", flush=True)
814
+ # votes.append({"model": m, "correct": res})
815
+
816
+ # # Vote with custom generation functions
817
+ # for idx, custom_fn in enumerate(self.custom_generation_fns):
818
+ # try:
819
+ # content = custom_fn(prompt)
820
+ # parsed = self.output_parser(content)
821
+
822
+ # # If parser returns None and fallback is enabled, do string comparison
823
+ # if parsed is None and self.fallback_comparison:
824
+ # res = str(model_output).strip() == str(ground_truth).strip()
825
+ # else:
826
+ # res = parsed
827
+
828
+ # if self.verbose:
829
+ # print(f"Custom function {idx} voted: {res}", flush=True)
830
+ # votes.append({"model": f"custom_fn_{idx}", "correct": res})
831
+ # except Exception as e:
832
+ # if self.verbose:
833
+ # print(f"Custom function {idx} failed: {e}", flush=True)
834
+ # # If custom function fails and fallback is enabled, do string comparison
835
+ # if self.fallback_comparison:
836
+ # res = str(model_output).strip() == str(ground_truth).strip()
837
+ # votes.append({"model": f"custom_fn_{idx}", "correct": res})
838
+ # else:
839
+ # raise
840
+
841
+ # if self.mode == "single":
842
+ # final = votes[0]["correct"]
843
+ # elif self.mode == "majority":
844
+ # true_votes = sum(v["correct"] for v in votes)
845
+ # false_votes = len(votes) - true_votes
846
+ # final = True if true_votes >= false_votes else False
847
+ # elif self.mode == "all":
848
+ # final = all(v["correct"] for v in votes)
849
+ # else:
850
+ # raise ValueError("mode must be 'majority', 'single', or 'all'")
851
+
852
+ # return {"correct": final, "mode": self.mode, "votes": votes}
299
853
 
300
- Returns:
301
- dict with size_limit, current_size, item_count, and percent_full
302
- """
303
- size_limit = self._dc.size_limit
304
- volume = self._dc.volume() # Current size in bytes
305
- count = len(self._dc) # Number of items
306
-
307
- return {
308
- "size_limit": size_limit,
309
- "current_size": volume,
310
- "item_count": count,
311
- "percent_full": (volume / size_limit) * 100 if size_limit > 0 else 0.0,
312
- }
313
-
314
- def print_stats(self):
315
- """Print human-readable cache statistics."""
316
- stats = self.get_stats()
317
-
318
- def human_size(bytes_val):
319
- """Convert bytes to human readable format."""
320
- for unit in ["B", "KB", "MB", "GB", "TB", "PB", "EB"]:
321
- if bytes_val < 1024.0:
322
- return f"{bytes_val:.2f} {unit}"
323
- bytes_val /= 1024.0
324
- return f"{bytes_val:.2f} EB"
325
-
326
- print("=" * 60)
327
- print("CACHE STATISTICS")
328
- print("=" * 60)
329
- print(f" Size limit: {human_size(stats['size_limit'])} ({stats['size_limit']:,} bytes)")
330
- print(f" Current size: {human_size(stats['current_size'])} ({stats['current_size']:,} bytes)")
331
- print(f" Items cached: {stats['item_count']}")
332
- print(f" % full: {stats['percent_full']:.6f}%")
333
- print("=" * 60)
854
+
855
+
856
+
857
+
858
+ import os
859
+ import time
860
+ import random
861
+ import re
862
+ import json
863
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
864
+ from enum import Enum
865
+ import litellm
866
+ from litellm import completion
867
+ from litellm.caching.caching import Cache
868
+
869
+
870
+ __all__ = ["LLMAsAJudge", "OutputParsers"]
871
+
872
+
873
+ class ReturnType(Enum):
874
+ """Enum for categorizing parser return types."""
875
+ BOOLEAN = "boolean"
876
+ SCALAR = "scalar"
877
+ MAP = "map"
878
+
879
+
880
+ class AggregationMode(Enum):
881
+ """Enum for aggregation modes across multiple voters."""
882
+ # Boolean modes
883
+ MAJORITY = "majority"
884
+ SINGLE = "single"
885
+ ALL = "all"
886
+ # Scalar modes
887
+ AVERAGE = "average"
888
+ MIN = "min"
889
+ MAX = "max"
890
+ MEDIAN = "median"
891
+
892
+
893
+ # Valid aggregation modes per return type
894
+ VALID_MODES = {
895
+ ReturnType.BOOLEAN: {AggregationMode.MAJORITY, AggregationMode.SINGLE, AggregationMode.ALL},
896
+ ReturnType.SCALAR: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
897
+ ReturnType.MAP: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
898
+ }
899
+
900
+ # Default aggregation modes per return type
901
+ DEFAULT_MODES = {
902
+ ReturnType.BOOLEAN: AggregationMode.MAJORITY,
903
+ ReturnType.SCALAR: AggregationMode.AVERAGE,
904
+ ReturnType.MAP: AggregationMode.AVERAGE,
905
+ }
906
+
907
+ # String to enum mapping (for backward compat)
908
+ MODE_STR_MAP = {
909
+ 'majority': AggregationMode.MAJORITY,
910
+ 'single': AggregationMode.SINGLE,
911
+ 'all': AggregationMode.ALL,
912
+ 'average': AggregationMode.AVERAGE,
913
+ 'min': AggregationMode.MIN,
914
+ 'max': AggregationMode.MAX,
915
+ 'median': AggregationMode.MEDIAN,
916
+ }
334
917
 
335
918
 
336
919
  class OutputParsers:
337
920
  """Stock output parsers for common judge output formats."""
338
921
 
922
+ # ========== BOOLEAN PARSERS ==========
923
+
339
924
  @staticmethod
340
925
  def right_wrong(s: str) -> Optional[bool]:
341
926
  """Parse 'right' or 'wrong' from the last 6 characters."""
@@ -372,6 +957,8 @@ class OutputParsers:
372
957
  return False
373
958
  return None
374
959
 
960
+ # ========== SCALAR PARSERS ==========
961
+
375
962
  @staticmethod
376
963
  def numeric_score(s: str) -> Optional[float]:
377
964
  """Extract first numeric value from the response."""
@@ -382,20 +969,188 @@ class OutputParsers:
382
969
  return float(match.group())
383
970
  return None
384
971
 
972
+ @staticmethod
973
+ def score_0_to_10(s: str) -> Optional[float]:
974
+ """Extract a score from 0-10, clamping to valid range."""
975
+ if not s:
976
+ return None
977
+ match = re.search(r'[-+]?\d*\.?\d+', s.strip())
978
+ if match:
979
+ return max(0.0, min(10.0, float(match.group())))
980
+ return None
981
+
982
+ @staticmethod
983
+ def score_0_to_100(s: str) -> Optional[float]:
984
+ """Extract a score from 0-100, clamping to valid range."""
985
+ if not s:
986
+ return None
987
+ match = re.search(r'[-+]?\d*\.?\d+', s.strip())
988
+ if match:
989
+ return max(0.0, min(100.0, float(match.group())))
990
+ return None
991
+
992
+ @staticmethod
993
+ def score_1_to_5(s: str) -> Optional[float]:
994
+ """Extract a score from 1-5, clamping to valid range."""
995
+ if not s:
996
+ return None
997
+ match = re.search(r'[-+]?\d*\.?\d+', s.strip())
998
+ if match:
999
+ return max(1.0, min(5.0, float(match.group())))
1000
+ return None
1001
+
1002
+ # ========== MAP PARSERS ==========
1003
+
385
1004
  @staticmethod
386
1005
  def json_extract(key: str) -> Callable[[str], Any]:
387
1006
  """Create a parser that extracts a specific key from JSON output."""
388
- import json
389
1007
  def parser(s: str) -> Any:
390
1008
  if not s:
391
1009
  return None
392
1010
  try:
393
- data = json.loads(s.strip())
1011
+ s = s.strip()
1012
+ if "```json" in s:
1013
+ start = s.find("```json") + 7
1014
+ end = s.find("```", start)
1015
+ s = s[start:end].strip()
1016
+ elif "```" in s:
1017
+ start = s.find("```") + 3
1018
+ end = s.find("```", start)
1019
+ s = s[start:end].strip()
1020
+ data = json.loads(s)
394
1021
  return data.get(key)
395
1022
  except (json.JSONDecodeError, AttributeError):
396
1023
  return None
397
1024
  return parser
398
1025
 
1026
    @staticmethod
    def json_map(keys: Optional[List[str]] = None) -> Callable[[str], Optional[Dict[str, float]]]:
        """
        Create a parser that extracts multiple keys from JSON output as a map.
        Returns Dict[str, float] or None.

        Args:
            keys: Keys to extract. When falsy (None or empty list), every
                key of the parsed JSON object is considered.

        More robust version that handles:
        - Code blocks with ```json or ```
        - Extra text before/after JSON
        - Various JSON formatting issues

        Numeric values are coerced to float; string values that parse as
        floats are accepted; all other values are silently skipped. Returns
        None when nothing numeric could be extracted.
        """
        def parser(s: str) -> Optional[Dict[str, float]]:
            if not s:
                return None
            try:
                s = s.strip()

                # Handle markdown code blocks. The `end > start` guard leaves
                # s untouched when the closing fence is missing.
                if "```json" in s.lower():
                    start = s.lower().find("```json") + 7
                    end = s.find("```", start)
                    if end > start:
                        s = s[start:end].strip()
                elif "```" in s:
                    start = s.find("```") + 3
                    end = s.find("```", start)
                    if end > start:
                        s = s[start:end].strip()

                # Try to find JSON object if there's extra text
                # Look for first { and last } and keep only that slice.
                if '{' in s and '}' in s:
                    start_brace = s.find('{')
                    end_brace = s.rfind('}')
                    if start_brace < end_brace:
                        s = s[start_brace:end_brace + 1]

                data = json.loads(s)
                if not isinstance(data, dict):
                    return None

                result = {}
                # Falsy `keys` (None or []) means "take every key present".
                target_keys = keys if keys else list(data.keys())

                for key in target_keys:
                    if key in data:
                        val = data[key]
                        # NOTE: bool is a subclass of int in Python, so JSON
                        # true/false become 1.0/0.0 here.
                        if isinstance(val, (int, float)):
                            result[key] = float(val)
                        elif isinstance(val, str):
                            # Accept numeric strings like "7.5".
                            try:
                                result[key] = float(val)
                            except ValueError:
                                pass

                return result if result else None
            except (json.JSONDecodeError, AttributeError, ValueError):
                return None
        return parser
1085
+
1086
+ @staticmethod
1087
+ def multi_score_pattern(pattern: str = r'(\w+):\s*([\d.]+)') -> Callable[[str], Optional[Dict[str, float]]]:
1088
+ """
1089
+ Create a parser that extracts key-value pairs using regex pattern.
1090
+ Default pattern matches "key: value" format.
1091
+ """
1092
+ def parser(s: str) -> Optional[Dict[str, float]]:
1093
+ if not s:
1094
+ return None
1095
+ matches = re.findall(pattern, s)
1096
+ if not matches:
1097
+ return None
1098
+ result = {}
1099
+ for key, val in matches:
1100
+ try:
1101
+ result[key.lower()] = float(val)
1102
+ except ValueError:
1103
+ pass
1104
+ return result if result else None
1105
+ return parser
1106
+
1107
+
1108
def _infer_return_type(value: Any) -> Optional[ReturnType]:
    """Infer the return type from a parsed value.

    bool is checked before int/float because bool is an int subclass; a dict
    only counts as MAP when every value is numeric. Returns None when the
    value matches no known type.
    """
    if value is None:
        return None
    # Order matters: isinstance(True, int) is True, so test bool first.
    if isinstance(value, bool):
        return ReturnType.BOOLEAN
    if isinstance(value, (int, float)):
        return ReturnType.SCALAR
    if isinstance(value, dict):
        if all(isinstance(entry, (int, float)) for entry in value.values()):
            return ReturnType.MAP
    return None
1119
+
1120
+
1121
def _validate_consistent_types(votes: List[Dict[str, Any]]) -> ReturnType:
    """Validate that all votes have consistent return types.

    Returns the single ReturnType shared by all parseable votes.
    Raises ValueError when no vote is parseable or when types are mixed.
    """
    seen: set = set()
    for vote in votes:
        parsed = vote.get("result")
        if parsed is None:
            continue
        inferred = _infer_return_type(parsed)
        if inferred is not None:
            seen.add(inferred)

    if not seen:
        raise ValueError("All parsers returned None - cannot determine return type")
    if len(seen) > 1:
        raise ValueError(
            f"Mixed return types detected: {[rt.value for rt in seen]}. "
            "All judges must return the same type (boolean, scalar, or map)."
        )
    return seen.pop()
1139
+
1140
+
1141
+ def _validate_consistent_map_keys(votes: List[Dict[str, Any]]) -> set:
1142
+ """Validate that all map votes have the same keys."""
1143
+ all_keys = None
1144
+ for vote in votes:
1145
+ result = vote.get("result")
1146
+ if isinstance(result, dict):
1147
+ keys = set(result.keys())
1148
+ if all_keys is None:
1149
+ all_keys = keys
1150
+ elif keys != all_keys:
1151
+ raise ValueError(f"Inconsistent map keys across voters. Expected {all_keys}, got {keys}")
1152
+ return all_keys or set()
1153
+
399
1154
 
400
1155
  class LLMAsAJudge:
401
1156
  BASE_TEMPLATE = """\
@@ -439,123 +1194,13 @@ Return a single numeric score from 0-10 indicating how well the model output mat
439
1194
  - 1-3 = poor match
440
1195
  - 0 = completely wrong
441
1196
  Output only the number. No explanation. No extra text.""",
1197
+ 'multi_score': """\
1198
+ Evaluate the model output against the ground truth on multiple dimensions.
1199
+ Return your scores in JSON format with numeric values (0-10 scale).
1200
+ Example: {"accuracy": 8, "helpfulness": 7, "relevance": 9}
1201
+ Output only valid JSON. No explanation. No extra text.""",
442
1202
  }
443
1203
 
444
-
445
-
446
-
447
- # def __init__(
448
- # self,
449
- # models: Optional[List[str]] = None,
450
- # config: Optional[Dict[str, Dict[str, Any]]] = None, # one dict for providers and models
451
- # base_headers: Optional[Dict[str, str]] = None,
452
- # wandb_project: Optional[str] = None,
453
- # custom_template: Optional[str] = None,
454
- # use_fully_custom_prompt: bool = False,
455
- # notes: Optional[str] = None,
456
- # output_parser: Optional[str] = 'right/wrong',
457
- # fallback_comparison: bool = True,
458
- # default_temperature: float = 0.0,
459
- # verbose: bool = False,
460
- # num_retries: int = 2, # per-call retries before giving up on that model
461
- # backoff_base: float = 0.5, # seconds
462
- # backoff_max: float = 4.0, # seconds
463
- # custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
464
- # mode: str = "majority", # "single", "majority", "all"
465
- # ):
466
- # """
467
- # config keys can be a provider name ("wandb", "openai", "anthropic")
468
- # or a full model name ("openai/gpt-4o-mini", "wandb/deepseek-ai/DeepSeek-V3.1").
469
-
470
- # Values can include:
471
- # api_base: Optional[str]
472
- # headers: Dict[str, str]
473
- # temperature: float
474
-
475
- # Precedence:
476
- # base_headers < provider config < model config
477
-
478
- # Args:
479
- # models: List of litellm model strings (e.g., ["openai/gpt-4", "anthropic/claude-3"])
480
- # custom_template: Template with placeholders for input/output/ground_truth
481
- # use_fully_custom_prompt: If True, pass complete prompt to judge(prompt=...).
482
- # When True, input/output/ground_truth must NOT be passed to judge()
483
- # output_parser: Parser name ('right/wrong', 'yes/no', 'pass/fail', 'numeric')
484
- # or custom function with signature (str) -> Any
485
- # fallback_comparison: If True and parser returns None, falls back to string comparison
486
- # custom_generation_fns: List of custom inference functions with signature fn(prompt: str) -> str
487
- # These will be used in addition to litellm models for voting.
488
- # mode: Voting mode - "majority" (default), "single" (first judge only), or "all" (unanimous)
489
- # """
490
- # self.models = models or []
491
- # self.custom_generation_fns = custom_generation_fns or []
492
-
493
- # # Validate that at least one judge is provided
494
- # if not self.models and not self.custom_generation_fns:
495
- # raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
496
-
497
- # # Validate mode
498
- # if mode not in ("majority", "single", "all"):
499
- # raise ValueError("mode must be 'majority', 'single', or 'all'")
500
-
501
- # self.config = config or {}
502
- # self.base_headers = dict(base_headers or {})
503
- # self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
504
- # self.notes = notes or ""
505
- # self.use_fully_custom_prompt = use_fully_custom_prompt
506
- # self.mode = mode
507
-
508
- # # Resolve output parser
509
- # parser_name = None
510
- # if isinstance(output_parser, str):
511
- # parser_map = {
512
- # 'right/wrong': OutputParsers.right_wrong,
513
- # 'pass/fail': OutputParsers.pass_fail,
514
- # 'yes/no': OutputParsers.yes_no,
515
- # 'numeric': OutputParsers.numeric_score,
516
- # }
517
- # if output_parser not in parser_map:
518
- # raise ValueError(f"Unknown parser '{output_parser}'. Available: {list(parser_map.keys())}")
519
- # self.output_parser = parser_map[output_parser]
520
- # parser_name = output_parser
521
- # else:
522
- # self.output_parser = output_parser
523
-
524
- # # Set template based on mode
525
- # if use_fully_custom_prompt:
526
- # self.template = None # No template in fully custom mode
527
- # elif custom_template:
528
- # self.template = custom_template
529
- # elif parser_name and parser_name in self.PARSER_INSTRUCTIONS:
530
- # self.template = self.BASE_TEMPLATE.format(
531
- # instruction=self.PARSER_INSTRUCTIONS[parser_name],
532
- # notes_section="{notes_section}",
533
- # input_block="{input_block}",
534
- # model_output="{model_output}",
535
- # ground_truth="{ground_truth}",
536
- # )
537
- # else:
538
- # # Default to right/wrong for custom parsers
539
- # self.template = self.BASE_TEMPLATE.format(
540
- # instruction=self.PARSER_INSTRUCTIONS['right/wrong'],
541
- # notes_section="{notes_section}",
542
- # input_block="{input_block}",
543
- # model_output="{model_output}",
544
- # ground_truth="{ground_truth}",
545
- # )
546
-
547
- # self.fallback_comparison = fallback_comparison
548
- # self.default_temperature = float(default_temperature)
549
- # self.verbose = verbose
550
- # self.num_retries = int(num_retries)
551
- # self.backoff_base = float(backoff_base)
552
- # self.backoff_max = float(backoff_max)
553
-
554
-
555
-
556
-
557
-
558
-
559
1204
  def __init__(
560
1205
  self,
561
1206
  models: Optional[List[str]] = None,
@@ -565,7 +1210,7 @@ Output only the number. No explanation. No extra text.""",
565
1210
  custom_template: Optional[str] = None,
566
1211
  use_fully_custom_prompt: bool = False,
567
1212
  notes: Optional[str] = None,
568
- output_parser: Optional[str] = 'right/wrong',
1213
+ output_parser: Optional[Union[str, Callable]] = 'right/wrong',
569
1214
  fallback_comparison: bool = True,
570
1215
  default_temperature: float = 0.0,
571
1216
  verbose: bool = False,
@@ -573,10 +1218,9 @@ Output only the number. No explanation. No extra text.""",
573
1218
  backoff_base: float = 0.5,
574
1219
  backoff_max: float = 4.0,
575
1220
  custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
576
- mode: str = "majority",
577
- use_cache: bool = False,
1221
+ mode: Optional[str] = None,
578
1222
  litellm_cache_dir: Optional[str] = None,
579
- cache_size_gb: Optional[float] = None,
1223
+ return_type: Optional[str] = None,
580
1224
  ):
581
1225
  self.models = models or []
582
1226
  self.custom_generation_fns = custom_generation_fns or []
@@ -584,15 +1228,11 @@ Output only the number. No explanation. No extra text.""",
584
1228
  if not self.models and not self.custom_generation_fns:
585
1229
  raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
586
1230
 
587
- if mode not in ("majority", "single", "all"):
588
- raise ValueError("mode must be 'majority', 'single', or 'all'")
589
-
590
1231
  self.config = config or {}
591
1232
  self.base_headers = dict(base_headers or {})
592
1233
  self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
593
1234
  self.notes = notes or ""
594
1235
  self.use_fully_custom_prompt = use_fully_custom_prompt
595
- self.mode = mode
596
1236
  self.fallback_comparison = fallback_comparison
597
1237
  self.default_temperature = float(default_temperature)
598
1238
  self.verbose = verbose
@@ -600,6 +1240,7 @@ Output only the number. No explanation. No extra text.""",
600
1240
  self.backoff_base = float(backoff_base)
601
1241
  self.backoff_max = float(backoff_max)
602
1242
 
1243
+ # Resolve output parser
603
1244
  parser_name = None
604
1245
  if isinstance(output_parser, str):
605
1246
  parser_map = {
@@ -607,6 +1248,7 @@ Output only the number. No explanation. No extra text.""",
607
1248
  'pass/fail': OutputParsers.pass_fail,
608
1249
  'yes/no': OutputParsers.yes_no,
609
1250
  'numeric': OutputParsers.numeric_score,
1251
+ 'multi_score': OutputParsers.json_map(),
610
1252
  }
611
1253
  if output_parser not in parser_map:
612
1254
  raise ValueError(f"Unknown parser '{output_parser}'")
@@ -615,6 +1257,42 @@ Output only the number. No explanation. No extra text.""",
615
1257
  else:
616
1258
  self.output_parser = output_parser
617
1259
 
1260
+ # Determine expected return type
1261
+ self._explicit_return_type: Optional[ReturnType] = None
1262
+ if return_type is not None:
1263
+ self._explicit_return_type = ReturnType(return_type)
1264
+ elif parser_name:
1265
+ if parser_name in ('right/wrong', 'pass/fail', 'yes/no'):
1266
+ self._explicit_return_type = ReturnType.BOOLEAN
1267
+ elif parser_name == 'numeric':
1268
+ self._explicit_return_type = ReturnType.SCALAR
1269
+ elif parser_name == 'multi_score':
1270
+ self._explicit_return_type = ReturnType.MAP
1271
+
1272
+ # Resolve aggregation mode
1273
+ if mode is not None:
1274
+ if mode not in MODE_STR_MAP:
1275
+ raise ValueError(f"Unknown mode '{mode}'. Available: {list(MODE_STR_MAP.keys())}")
1276
+ self._mode = MODE_STR_MAP[mode]
1277
+ else:
1278
+ if self._explicit_return_type:
1279
+ self._mode = DEFAULT_MODES[self._explicit_return_type]
1280
+ else:
1281
+ self._mode = AggregationMode.MAJORITY
1282
+
1283
+ # Validate mode against return type if known
1284
+ if self._explicit_return_type:
1285
+ valid_modes = VALID_MODES[self._explicit_return_type]
1286
+ if self._mode not in valid_modes:
1287
+ raise ValueError(
1288
+ f"Mode '{self._mode.value}' not valid for return type '{self._explicit_return_type.value}'. "
1289
+ f"Valid: {[m.value for m in valid_modes]}"
1290
+ )
1291
+
1292
+ # For backward compat, expose mode as string
1293
+ self.mode = self._mode.value
1294
+
1295
+ # Set template
618
1296
  if use_fully_custom_prompt:
619
1297
  self.template = None
620
1298
  elif custom_template:
@@ -636,50 +1314,23 @@ Output only the number. No explanation. No extra text.""",
636
1314
  ground_truth="{ground_truth}",
637
1315
  )
638
1316
 
639
- # optional local cache setup
640
- # Enable cache if use_cache=True OR if litellm_cache_dir is explicitly provided (backward compatible)
641
- self.cache_enabled = use_cache or (litellm_cache_dir is not None)
1317
+ # Optional local cache setup
1318
+ self.cache_enabled = litellm_cache_dir is not None
642
1319
  if self.cache_enabled:
643
- # Only set up cache if it hasn't been set up already
644
- if litellm.cache is None:
645
- # Set default cache directory if not specified
646
- if litellm_cache_dir is None:
647
- litellm_cache_dir = ".litellm_cache"
648
-
649
- # Convert GB to bytes if specified, otherwise unlimited
650
- size_limit_bytes = None if cache_size_gb is None else int(cache_size_gb * 1024 * 1024 * 1024)
651
- cache_backend = UnlimitedDiskCache(litellm_cache_dir, size_limit=size_limit_bytes)
652
- litellm.cache = Cache(disk_cache_dir=litellm_cache_dir)
653
- litellm.cache.cache = cache_backend
654
-
655
-
656
-
657
-
658
-
659
-
660
-
661
-
1320
+ litellm.cache = Cache(type="disk", disk_cache_dir=litellm_cache_dir)
662
1321
 
663
1322
  def _build_prompt(self, input: Any, model_output: Any, ground_truth: Any) -> str:
664
1323
  notes_section = f"notes:\n{self.notes}\n" if self.notes else ""
665
1324
  input_text = str(input) if input not in (None, "") else "[omitted input for brevity]"
666
- return self.template.format(
667
- notes_section=notes_section,
668
- input_block=input_text,
669
- model_output=str(model_output),
670
- ground_truth=str(ground_truth),
671
- )
672
1325
 
673
- @staticmethod
674
- def _last6_right_wrong(s: str):
675
- if not s:
676
- return None
677
- tail = s.strip()[-6:].lower()
678
- if "right" in tail:
679
- return True
680
- if "wrong" in tail:
681
- return False
682
- return None
1326
+ # Use string replacement instead of .format() to avoid issues with { } in template
1327
+ # (e.g., JSON examples in multi_score instructions)
1328
+ prompt = self.template
1329
+ prompt = prompt.replace("{notes_section}", notes_section)
1330
+ prompt = prompt.replace("{input_block}", input_text)
1331
+ prompt = prompt.replace("{model_output}", str(model_output))
1332
+ prompt = prompt.replace("{ground_truth}", str(ground_truth))
1333
+ return prompt
683
1334
 
684
1335
  def _resolve_per_model(self, model: str) -> Tuple[Optional[str], Dict[str, str], float]:
685
1336
  provider = model.split("/", 1)[0] if "/" in model else model
@@ -688,7 +1339,6 @@ Output only the number. No explanation. No extra text.""",
688
1339
  headers: Dict[str, str] = dict(self.base_headers)
689
1340
  temperature: float = self.default_temperature
690
1341
 
691
- # provider-level defaults
692
1342
  pc = self.config.get(provider, {})
693
1343
  if pc.get("api_base") is not None:
694
1344
  api_base = pc["api_base"]
@@ -696,7 +1346,6 @@ Output only the number. No explanation. No extra text.""",
696
1346
  if "temperature" in pc:
697
1347
  temperature = float(pc["temperature"])
698
1348
 
699
- # model-level overrides
700
1349
  mc = self.config.get(model, {})
701
1350
  if mc.get("api_base") is not None:
702
1351
  api_base = mc["api_base"]
@@ -704,7 +1353,6 @@ Output only the number. No explanation. No extra text.""",
704
1353
  if "temperature" in mc:
705
1354
  temperature = float(mc["temperature"])
706
1355
 
707
- # wandb defaults
708
1356
  if provider == "wandb":
709
1357
  if api_base is None:
710
1358
  api_base = "https://api.inference.wandb.ai/v1"
@@ -726,15 +1374,6 @@ Output only the number. No explanation. No extra text.""",
726
1374
  last_err = None
727
1375
  for i in range(attempts):
728
1376
  try:
729
- # resp = completion(
730
- # model=model,
731
- # api_base=api_base, # None uses provider default
732
- # messages=[{"role": "user", "content": prompt}],
733
- # temperature=temperature,
734
- # max_tokens=max_tokens,
735
- # extra_headers=headers,
736
- # )
737
-
738
1377
  resp = completion(
739
1378
  model=model,
740
1379
  api_base=api_base,
@@ -743,20 +1382,20 @@ Output only the number. No explanation. No extra text.""",
743
1382
  max_tokens=max_tokens,
744
1383
  extra_headers=headers,
745
1384
  caching=self.cache_enabled
746
- )
1385
+ )
747
1386
  return (resp.choices[0].message.content or "").strip()
748
1387
  except Exception as e:
749
1388
  last_err = e
750
1389
  if i == attempts - 1:
751
1390
  break
752
1391
  sleep_s = min(self.backoff_max, self.backoff_base * (2 ** i))
753
- jitter = sleep_s * (0.1 * (2 * random.random() - 1.0)) # ±10%
1392
+ jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))
754
1393
  if self.verbose:
755
1394
  print(f"[retry {i+1}/{attempts-1}] {model} error: {e} — sleeping {max(0.0, sleep_s + jitter):.2f}s", flush=True)
756
1395
  time.sleep(max(0.0, sleep_s + jitter))
757
- raise last_err # fail after retries
1396
+ raise last_err
758
1397
 
759
- def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any):
1398
+ def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any) -> Any:
760
1399
  api_base, headers, temperature = self._resolve_per_model(model)
761
1400
 
762
1401
  content = self._attempt_completion(
@@ -768,15 +1407,80 @@ Output only the number. No explanation. No extra text.""",
768
1407
  max_tokens=max_tokens,
769
1408
  )
770
1409
 
771
- # Use the instance parser
772
1410
  parsed = self.output_parser(content)
773
1411
 
774
- # If parser returns None and fallback is enabled, do string comparison
775
1412
  if parsed is None and self.fallback_comparison:
776
1413
  return str(model_output).strip() == str(ground_truth).strip()
777
1414
 
778
1415
  return parsed
779
1416
 
1417
+ def _aggregate_boolean(self, votes: List[Dict[str, Any]]) -> bool:
1418
+ results = [v["result"] for v in votes if v["result"] is not None]
1419
+ if not results:
1420
+ raise ValueError("No valid votes to aggregate")
1421
+
1422
+ if self._mode == AggregationMode.SINGLE:
1423
+ return bool(results[0])
1424
+ elif self._mode == AggregationMode.MAJORITY:
1425
+ return sum(1 for r in results if r) >= len(results) / 2
1426
+ elif self._mode == AggregationMode.ALL:
1427
+ return all(results)
1428
+ else:
1429
+ raise ValueError(f"Invalid mode for boolean: {self._mode}")
1430
+
1431
+ def _aggregate_scalar(self, votes: List[Dict[str, Any]]) -> float:
1432
+ results = [float(v["result"]) for v in votes if v["result"] is not None]
1433
+ if not results:
1434
+ raise ValueError("No valid votes to aggregate")
1435
+
1436
+ if self._mode == AggregationMode.SINGLE:
1437
+ return results[0]
1438
+ elif self._mode == AggregationMode.AVERAGE:
1439
+ return sum(results) / len(results)
1440
+ elif self._mode == AggregationMode.MIN:
1441
+ return min(results)
1442
+ elif self._mode == AggregationMode.MAX:
1443
+ return max(results)
1444
+ elif self._mode == AggregationMode.MEDIAN:
1445
+ s = sorted(results)
1446
+ n = len(s)
1447
+ mid = n // 2
1448
+ return (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
1449
+ else:
1450
+ raise ValueError(f"Invalid mode for scalar: {self._mode}")
1451
+
1452
+ def _aggregate_map(self, votes: List[Dict[str, Any]]) -> Dict[str, float]:
1453
+ valid = [v["result"] for v in votes if v["result"] is not None and isinstance(v["result"], dict)]
1454
+ if not valid:
1455
+ raise ValueError("No valid map votes to aggregate")
1456
+
1457
+ keys = set()
1458
+ for v in valid:
1459
+ keys.update(v.keys())
1460
+
1461
+ if self._mode == AggregationMode.SINGLE:
1462
+ return valid[0]
1463
+
1464
+ result = {}
1465
+ for key in keys:
1466
+ values = [v[key] for v in valid if key in v]
1467
+ if not values:
1468
+ continue
1469
+
1470
+ if self._mode == AggregationMode.AVERAGE:
1471
+ result[key] = sum(values) / len(values)
1472
+ elif self._mode == AggregationMode.MIN:
1473
+ result[key] = min(values)
1474
+ elif self._mode == AggregationMode.MAX:
1475
+ result[key] = max(values)
1476
+ elif self._mode == AggregationMode.MEDIAN:
1477
+ s = sorted(values)
1478
+ n = len(s)
1479
+ mid = n // 2
1480
+ result[key] = (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
1481
+
1482
+ return result
1483
+
780
1484
  def judge(
781
1485
  self,
782
1486
  input: Any = None,
@@ -784,23 +1488,25 @@ Output only the number. No explanation. No extra text.""",
784
1488
  ground_truth: Any = None,
785
1489
  prompt: Optional[str] = None,
786
1490
  max_tokens: int = 10000,
787
- ):
788
- # Validation for fully_custom_prompt mode
1491
+ ) -> Dict[str, Any]:
1492
+ """
1493
+ Run the judge evaluation.
1494
+
1495
+ Returns dict with:
1496
+ - 'correct': bool or None (bool for boolean type, None for scalar/map)
1497
+ - 'scores': float or Dict[str, float] or None (for scalar/map, None for boolean)
1498
+ - 'mode': aggregation mode string
1499
+ - 'votes': list of individual judge votes
1500
+ """
789
1501
  if self.use_fully_custom_prompt:
790
1502
  if prompt is None:
791
- raise ValueError(
792
- "When use_fully_custom_prompt=True, you must pass prompt to judge()."
793
- )
1503
+ raise ValueError("When use_fully_custom_prompt=True, you must pass prompt to judge().")
794
1504
  if input is not None or model_output is not None or ground_truth is not None:
795
1505
  raise ValueError(
796
- "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth to judge(). "
797
- "Only pass the complete prompt."
1506
+ "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth."
798
1507
  )
799
1508
  elif prompt is not None:
800
- raise ValueError(
801
- "prompt parameter can only be used when use_fully_custom_prompt=True. "
802
- "Otherwise, use input/model_output/ground_truth."
803
- )
1509
+ raise ValueError("prompt parameter can only be used when use_fully_custom_prompt=True.")
804
1510
  else:
805
1511
  prompt = self._build_prompt(input, model_output, ground_truth)
806
1512
 
@@ -808,18 +1514,27 @@ Output only the number. No explanation. No extra text.""",
808
1514
 
809
1515
  # Vote with litellm models
810
1516
  for m in self.models:
811
- res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
812
- if self.verbose:
813
- print(f"Model {m} voted: {res}", flush=True)
814
- votes.append({"model": m, "correct": res})
1517
+ try:
1518
+ res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
1519
+ if self.verbose:
1520
+ print(f"Model {m} voted: {res}", flush=True)
1521
+ votes.append({"model": m, "result": res})
1522
+ except Exception as e:
1523
+ if self.verbose:
1524
+ print(f"Model {m} failed: {e}", flush=True)
1525
+ if self.fallback_comparison:
1526
+ res = str(model_output).strip() == str(ground_truth).strip()
1527
+ votes.append({"model": m, "result": res, "error": str(e)})
1528
+ else:
1529
+ raise
815
1530
 
816
1531
  # Vote with custom generation functions
817
1532
  for idx, custom_fn in enumerate(self.custom_generation_fns):
1533
+ fn_name = f"custom_fn_{idx}"
818
1534
  try:
819
1535
  content = custom_fn(prompt)
820
1536
  parsed = self.output_parser(content)
821
1537
 
822
- # If parser returns None and fallback is enabled, do string comparison
823
1538
  if parsed is None and self.fallback_comparison:
824
1539
  res = str(model_output).strip() == str(ground_truth).strip()
825
1540
  else:
@@ -827,27 +1542,70 @@ Output only the number. No explanation. No extra text.""",
827
1542
 
828
1543
  if self.verbose:
829
1544
  print(f"Custom function {idx} voted: {res}", flush=True)
830
- votes.append({"model": f"custom_fn_{idx}", "correct": res})
1545
+ votes.append({"model": fn_name, "result": res})
831
1546
  except Exception as e:
832
1547
  if self.verbose:
833
1548
  print(f"Custom function {idx} failed: {e}", flush=True)
834
- # If custom function fails and fallback is enabled, do string comparison
835
1549
  if self.fallback_comparison:
836
1550
  res = str(model_output).strip() == str(ground_truth).strip()
837
- votes.append({"model": f"custom_fn_{idx}", "correct": res})
1551
+ votes.append({"model": fn_name, "result": res, "error": str(e)})
838
1552
  else:
839
1553
  raise
840
1554
 
841
- if self.mode == "single":
842
- final = votes[0]["correct"]
843
- elif self.mode == "majority":
844
- true_votes = sum(v["correct"] for v in votes)
845
- false_votes = len(votes) - true_votes
846
- final = True if true_votes >= false_votes else False
847
- elif self.mode == "all":
848
- final = all(v["correct"] for v in votes)
1555
+ # Determine return type
1556
+ return_type = self._explicit_return_type
1557
+ if return_type is None:
1558
+ return_type = _validate_consistent_types(votes)
849
1559
  else:
850
- raise ValueError("mode must be 'majority', 'single', or 'all'")
851
-
852
- return {"correct": final, "mode": self.mode, "votes": votes}
853
-
1560
+ actual = _validate_consistent_types(votes)
1561
+ if actual != return_type:
1562
+ raise ValueError(f"Expected return type '{return_type.value}' but got '{actual.value}'")
1563
+
1564
+ # Auto-correct mode if needed
1565
+ if self._mode not in VALID_MODES[return_type]:
1566
+ self._mode = DEFAULT_MODES[return_type]
1567
+ self.mode = self._mode.value
1568
+
1569
+ # Validate map keys
1570
+ if return_type == ReturnType.MAP:
1571
+ _validate_consistent_map_keys(votes)
1572
+
1573
+ # Aggregate
1574
+ if return_type == ReturnType.BOOLEAN:
1575
+ final = self._aggregate_boolean(votes)
1576
+ elif return_type == ReturnType.SCALAR:
1577
+ final = self._aggregate_scalar(votes)
1578
+ elif return_type == ReturnType.MAP:
1579
+ final = self._aggregate_map(votes)
1580
+ else:
1581
+ raise ValueError(f"Unknown return type: {return_type}")
1582
+
1583
+ # Build backward-compatible response
1584
+ # Boolean: correct=bool, scores=None
1585
+ # Scalar: correct=score, scores=score (both fields for convenience)
1586
+ # Map: correct=None, scores=map
1587
+ if return_type == ReturnType.BOOLEAN:
1588
+ # Also put "correct" in each vote for backward compat
1589
+ for v in votes:
1590
+ v["correct"] = v["result"]
1591
+ return {
1592
+ "correct": final,
1593
+ "scores": None,
1594
+ "mode": self.mode,
1595
+ "votes": votes,
1596
+ }
1597
+ elif return_type == ReturnType.SCALAR:
1598
+ # For scalar, put score in both correct and scores for convenience
1599
+ return {
1600
+ "correct": final,
1601
+ "scores": final,
1602
+ "mode": self.mode,
1603
+ "votes": votes,
1604
+ }
1605
+ else: # MAP
1606
+ return {
1607
+ "correct": None,
1608
+ "scores": final,
1609
+ "mode": self.mode,
1610
+ "votes": votes,
1611
+ }