llmasajudge 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
llmasajudge/__init__.py CHANGED
@@ -201,141 +201,726 @@
201
201
 
202
202
 
203
203
 
204
- import os
205
- import time
206
- import random
207
- import re
208
- from typing import Any, Callable, Dict, List, Optional, Tuple
209
- import litellm
210
- from litellm import completion
211
- from litellm.caching.caching import Cache
204
+ # import os
205
+ # import time
206
+ # import random
207
+ # import re
208
+ # from typing import Any, Callable, Dict, List, Optional, Tuple
209
+ # import litellm
210
+ # from litellm import completion
211
+ # from litellm.caching.caching import Cache
212
+
213
+
214
+ # __all__ = ["LLMAsAJudge", "OutputParsers"]
215
+
216
+
217
+ # class UnlimitedDiskCache:
218
+ # """
219
+ # Drop-in replacement backend with 'unlimited' size for LiteLLM cache.
220
+
221
+ # This wraps diskcache.Cache with a very large size limit (2^62 bytes ~ 4.6 exabytes)
222
+ # to effectively disable automatic cache eviction, allowing the cache to grow
223
+ # without size constraints.
224
+ # """
225
+
226
+ # def __init__(self, directory, size_limit=None):
227
+ # """
228
+ # Initialize unlimited disk cache.
229
+
230
+ # Args:
231
+ # directory: Path to cache directory
232
+ # size_limit: Optional size limit in bytes. If None, uses 2^62 bytes (~4.6 exabytes)
233
+ # """
234
+ # import diskcache as dc
235
+
236
+ # # Set to very large cap so culling never triggers (effectively unlimited)
237
+ # cap = size_limit if size_limit is not None else (1 << 62)
238
+ # self._dc = dc.Cache(directory, size_limit=cap)
239
+
240
+ # # Sync API used by LiteLLM
241
+ # def get_cache(self, key, **kwargs):
242
+ # """Get value from cache by key."""
243
+ # return self._dc.get(key)
244
+
245
+ # def set_cache(self, key, value, ttl=None, **kwargs):
246
+ # """Set value in cache with optional TTL."""
247
+ # expire = None if ttl is None else float(ttl)
248
+ # self._dc.set(key, value, expire=expire)
249
+
250
+ # # Async API used by LiteLLM
251
+ # async def async_get_cache(self, key, **kwargs):
252
+ # """Async get value from cache by key."""
253
+ # return self.get_cache(key, **kwargs)
254
+
255
+ # async def async_set_cache(self, key, value, ttl=None, **kwargs):
256
+ # """Async set value in cache with optional TTL."""
257
+ # return self.set_cache(key, value, ttl=ttl, **kwargs)
258
+
259
+ # async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
260
+ # """
261
+ # Async batch set multiple cache entries.
262
+
263
+ # Args:
264
+ # cache_list: List of (key, value) tuples
265
+ # ttl: Optional time-to-live in seconds
266
+ # """
267
+ # for k, v in cache_list:
268
+ # self.set_cache(k, v, ttl=ttl)
269
+
270
+ # async def batch_cache_write(self, key, value, ttl=None, **kwargs):
271
+ # """Async batch write (single entry)."""
272
+ # self.set_cache(key, value, ttl=ttl)
273
+
274
+ # async def ping(self):
275
+ # """Async ping check."""
276
+ # return True
277
+
278
+ # async def delete_cache_keys(self, keys):
279
+ # """
280
+ # Async delete multiple cache keys.
281
+
282
+ # Args:
283
+ # keys: List of keys to delete
284
+ # """
285
+ # for k in keys:
286
+ # try:
287
+ # del self._dc[k]
288
+ # except KeyError:
289
+ # pass
290
+ # return True
291
+
292
+ # async def disconnect(self):
293
+ # """Async disconnect and close cache."""
294
+ # self._dc.close()
295
+
296
+ # def get_stats(self):
297
+ # """
298
+ # Get cache statistics.
299
+
300
+ # Returns:
301
+ # dict with size_limit, current_size, item_count, and percent_full
302
+ # """
303
+ # size_limit = self._dc.size_limit
304
+ # volume = self._dc.volume() # Current size in bytes
305
+ # count = len(self._dc) # Number of items
306
+
307
+ # return {
308
+ # "size_limit": size_limit,
309
+ # "current_size": volume,
310
+ # "item_count": count,
311
+ # "percent_full": (volume / size_limit) * 100 if size_limit > 0 else 0.0,
312
+ # }
313
+
314
+ # def print_stats(self):
315
+ # """Print human-readable cache statistics."""
316
+ # stats = self.get_stats()
317
+
318
+ # def human_size(bytes_val):
319
+ # """Convert bytes to human readable format."""
320
+ # for unit in ["B", "KB", "MB", "GB", "TB", "PB", "EB"]:
321
+ # if bytes_val < 1024.0:
322
+ # return f"{bytes_val:.2f} {unit}"
323
+ # bytes_val /= 1024.0
324
+ # return f"{bytes_val:.2f} EB"
325
+
326
+ # print("=" * 60)
327
+ # print("CACHE STATISTICS")
328
+ # print("=" * 60)
329
+ # print(f" Size limit: {human_size(stats['size_limit'])} ({stats['size_limit']:,} bytes)")
330
+ # print(f" Current size: {human_size(stats['current_size'])} ({stats['current_size']:,} bytes)")
331
+ # print(f" Items cached: {stats['item_count']}")
332
+ # print(f" % full: {stats['percent_full']:.6f}%")
333
+ # print("=" * 60)
334
+
335
+
336
+ # class OutputParsers:
337
+ # """Stock output parsers for common judge output formats."""
212
338
 
339
+ # @staticmethod
340
+ # def right_wrong(s: str) -> Optional[bool]:
341
+ # """Parse 'right' or 'wrong' from the last 6 characters."""
342
+ # if not s:
343
+ # return None
344
+ # tail = s.strip()[-6:].lower()
345
+ # if "right" in tail:
346
+ # return True
347
+ # if "wrong" in tail:
348
+ # return False
349
+ # return None
213
350
 
214
- __all__ = ["LLMAsAJudge", "OutputParsers"]
351
+ # @staticmethod
352
+ # def pass_fail(s: str) -> Optional[bool]:
353
+ # """Parse 'pass' or 'fail' from the response."""
354
+ # if not s:
355
+ # return None
356
+ # text = s.strip().lower()
357
+ # if "pass" in text:
358
+ # return True
359
+ # if "fail" in text:
360
+ # return False
361
+ # return None
215
362
 
363
+ # @staticmethod
364
+ # def yes_no(s: str) -> Optional[bool]:
365
+ # """Parse 'yes' or 'no' from the response."""
366
+ # if not s:
367
+ # return None
368
+ # text = s.strip().lower()
369
+ # if "yes" in text:
370
+ # return True
371
+ # if "no" in text:
372
+ # return False
373
+ # return None
216
374
 
217
- class UnlimitedDiskCache:
218
- """
219
- Drop-in replacement backend with 'unlimited' size for LiteLLM cache.
375
+ # @staticmethod
376
+ # def numeric_score(s: str) -> Optional[float]:
377
+ # """Extract first numeric value from the response."""
378
+ # if not s:
379
+ # return None
380
+ # match = re.search(r'[-+]?\d*\.?\d+', s.strip())
381
+ # if match:
382
+ # return float(match.group())
383
+ # return None
220
384
 
221
- This wraps diskcache.Cache with a very large size limit (2^62 bytes ~ 4.6 exabytes)
222
- to effectively disable automatic cache eviction, allowing the cache to grow
223
- without size constraints.
224
- """
385
+ # @staticmethod
386
+ # def json_extract(key: str) -> Callable[[str], Any]:
387
+ # """Create a parser that extracts a specific key from JSON output."""
388
+ # import json
389
+ # def parser(s: str) -> Any:
390
+ # if not s:
391
+ # return None
392
+ # try:
393
+ # data = json.loads(s.strip())
394
+ # return data.get(key)
395
+ # except (json.JSONDecodeError, AttributeError):
396
+ # return None
397
+ # return parser
398
+
399
+
400
+ # class LLMAsAJudge:
401
+ # BASE_TEMPLATE = """\
402
+ # You are a judge. Read input, model_output, and ground_truth.
403
+ # {instruction}
404
+ # ##################
405
+ # {notes_section}### input:
406
+ # {input_block}
407
+ # ##################
408
+ # model's output:
409
+ # {model_output}
410
+ # ##################
411
+ # ground_truth answer:
412
+ # {ground_truth}
413
+ # ##################
414
+ # """
225
415
 
226
- def __init__(self, directory, size_limit=None):
227
- """
228
- Initialize unlimited disk cache.
416
+ # PARSER_INSTRUCTIONS = {
417
+ # 'right/wrong': """\
418
+ # Return exactly one word: right or wrong.
419
+ # Rules:
420
+ # - Treat extra words or punctuation as irrelevant if the same final value is present.
421
+ # - Output must be exactly right or wrong. No JSON. No quotes. No extra text.""",
422
+ # 'yes/no': """\
423
+ # Return exactly one word: yes or no.
424
+ # Answer yes if the model output matches the ground truth, no otherwise.
425
+ # Rules:
426
+ # - Treat extra words or punctuation as irrelevant if the same final value is present.
427
+ # - Output must be exactly yes or no. No JSON. No quotes. No extra text.""",
428
+ # 'pass/fail': """\
429
+ # Return exactly one word: pass or fail.
430
+ # Answer pass if the model output matches the ground truth, fail otherwise.
431
+ # Rules:
432
+ # - Treat extra words or punctuation as irrelevant if the same final value is present.
433
+ # - Output must be exactly pass or fail. No JSON. No quotes. No extra text.""",
434
+ # 'numeric': """\
435
+ # Return a single numeric score from 0-10 indicating how well the model output matches the ground truth.
436
+ # - 10 = perfect match
437
+ # - 7-9 = close match with minor differences
438
+ # - 4-6 = partial match
439
+ # - 1-3 = poor match
440
+ # - 0 = completely wrong
441
+ # Output only the number. No explanation. No extra text.""",
442
+ # }
443
+
444
+
445
+
446
+
447
+ # # def __init__(
448
+ # # self,
449
+ # # models: Optional[List[str]] = None,
450
+ # # config: Optional[Dict[str, Dict[str, Any]]] = None, # one dict for providers and models
451
+ # # base_headers: Optional[Dict[str, str]] = None,
452
+ # # wandb_project: Optional[str] = None,
453
+ # # custom_template: Optional[str] = None,
454
+ # # use_fully_custom_prompt: bool = False,
455
+ # # notes: Optional[str] = None,
456
+ # # output_parser: Optional[str] = 'right/wrong',
457
+ # # fallback_comparison: bool = True,
458
+ # # default_temperature: float = 0.0,
459
+ # # verbose: bool = False,
460
+ # # num_retries: int = 2, # per-call retries before giving up on that model
461
+ # # backoff_base: float = 0.5, # seconds
462
+ # # backoff_max: float = 4.0, # seconds
463
+ # # custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
464
+ # # mode: str = "majority", # "single", "majority", "all"
465
+ # # ):
466
+ # # """
467
+ # # config keys can be a provider name ("wandb", "openai", "anthropic")
468
+ # # or a full model name ("openai/gpt-4o-mini", "wandb/deepseek-ai/DeepSeek-V3.1").
469
+
470
+ # # Values can include:
471
+ # # api_base: Optional[str]
472
+ # # headers: Dict[str, str]
473
+ # # temperature: float
474
+
475
+ # # Precedence:
476
+ # # base_headers < provider config < model config
477
+
478
+ # # Args:
479
+ # # models: List of litellm model strings (e.g., ["openai/gpt-4", "anthropic/claude-3"])
480
+ # # custom_template: Template with placeholders for input/output/ground_truth
481
+ # # use_fully_custom_prompt: If True, pass complete prompt to judge(prompt=...).
482
+ # # When True, input/output/ground_truth must NOT be passed to judge()
483
+ # # output_parser: Parser name ('right/wrong', 'yes/no', 'pass/fail', 'numeric')
484
+ # # or custom function with signature (str) -> Any
485
+ # # fallback_comparison: If True and parser returns None, falls back to string comparison
486
+ # # custom_generation_fns: List of custom inference functions with signature fn(prompt: str) -> str
487
+ # # These will be used in addition to litellm models for voting.
488
+ # # mode: Voting mode - "majority" (default), "single" (first judge only), or "all" (unanimous)
489
+ # # """
490
+ # # self.models = models or []
491
+ # # self.custom_generation_fns = custom_generation_fns or []
492
+
493
+ # # # Validate that at least one judge is provided
494
+ # # if not self.models and not self.custom_generation_fns:
495
+ # # raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
496
+
497
+ # # # Validate mode
498
+ # # if mode not in ("majority", "single", "all"):
499
+ # # raise ValueError("mode must be 'majority', 'single', or 'all'")
500
+
501
+ # # self.config = config or {}
502
+ # # self.base_headers = dict(base_headers or {})
503
+ # # self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
504
+ # # self.notes = notes or ""
505
+ # # self.use_fully_custom_prompt = use_fully_custom_prompt
506
+ # # self.mode = mode
507
+
508
+ # # # Resolve output parser
509
+ # # parser_name = None
510
+ # # if isinstance(output_parser, str):
511
+ # # parser_map = {
512
+ # # 'right/wrong': OutputParsers.right_wrong,
513
+ # # 'pass/fail': OutputParsers.pass_fail,
514
+ # # 'yes/no': OutputParsers.yes_no,
515
+ # # 'numeric': OutputParsers.numeric_score,
516
+ # # }
517
+ # # if output_parser not in parser_map:
518
+ # # raise ValueError(f"Unknown parser '{output_parser}'. Available: {list(parser_map.keys())}")
519
+ # # self.output_parser = parser_map[output_parser]
520
+ # # parser_name = output_parser
521
+ # # else:
522
+ # # self.output_parser = output_parser
523
+
524
+ # # # Set template based on mode
525
+ # # if use_fully_custom_prompt:
526
+ # # self.template = None # No template in fully custom mode
527
+ # # elif custom_template:
528
+ # # self.template = custom_template
529
+ # # elif parser_name and parser_name in self.PARSER_INSTRUCTIONS:
530
+ # # self.template = self.BASE_TEMPLATE.format(
531
+ # # instruction=self.PARSER_INSTRUCTIONS[parser_name],
532
+ # # notes_section="{notes_section}",
533
+ # # input_block="{input_block}",
534
+ # # model_output="{model_output}",
535
+ # # ground_truth="{ground_truth}",
536
+ # # )
537
+ # # else:
538
+ # # # Default to right/wrong for custom parsers
539
+ # # self.template = self.BASE_TEMPLATE.format(
540
+ # # instruction=self.PARSER_INSTRUCTIONS['right/wrong'],
541
+ # # notes_section="{notes_section}",
542
+ # # input_block="{input_block}",
543
+ # # model_output="{model_output}",
544
+ # # ground_truth="{ground_truth}",
545
+ # # )
546
+
547
+ # # self.fallback_comparison = fallback_comparison
548
+ # # self.default_temperature = float(default_temperature)
549
+ # # self.verbose = verbose
550
+ # # self.num_retries = int(num_retries)
551
+ # # self.backoff_base = float(backoff_base)
552
+ # # self.backoff_max = float(backoff_max)
553
+
554
+
555
+
556
+
557
+
558
+
559
+ # def __init__(
560
+ # self,
561
+ # models: Optional[List[str]] = None,
562
+ # config: Optional[Dict[str, Dict[str, Any]]] = None,
563
+ # base_headers: Optional[Dict[str, str]] = None,
564
+ # wandb_project: Optional[str] = None,
565
+ # custom_template: Optional[str] = None,
566
+ # use_fully_custom_prompt: bool = False,
567
+ # notes: Optional[str] = None,
568
+ # output_parser: Optional[str] = 'right/wrong',
569
+ # fallback_comparison: bool = True,
570
+ # default_temperature: float = 0.0,
571
+ # verbose: bool = False,
572
+ # num_retries: int = 2,
573
+ # backoff_base: float = 0.5,
574
+ # backoff_max: float = 4.0,
575
+ # custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
576
+ # mode: str = "majority",
577
+ # use_cache: bool = False,
578
+ # litellm_cache_dir: Optional[str] = None,
579
+ # cache_size_gb: Optional[float] = None,
580
+ # ):
581
+ # self.models = models or []
582
+ # self.custom_generation_fns = custom_generation_fns or []
583
+
584
+ # if not self.models and not self.custom_generation_fns:
585
+ # raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
586
+
587
+ # if mode not in ("majority", "single", "all"):
588
+ # raise ValueError("mode must be 'majority', 'single', or 'all'")
589
+
590
+ # self.config = config or {}
591
+ # self.base_headers = dict(base_headers or {})
592
+ # self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
593
+ # self.notes = notes or ""
594
+ # self.use_fully_custom_prompt = use_fully_custom_prompt
595
+ # self.mode = mode
596
+ # self.fallback_comparison = fallback_comparison
597
+ # self.default_temperature = float(default_temperature)
598
+ # self.verbose = verbose
599
+ # self.num_retries = int(num_retries)
600
+ # self.backoff_base = float(backoff_base)
601
+ # self.backoff_max = float(backoff_max)
602
+
603
+ # parser_name = None
604
+ # if isinstance(output_parser, str):
605
+ # parser_map = {
606
+ # 'right/wrong': OutputParsers.right_wrong,
607
+ # 'pass/fail': OutputParsers.pass_fail,
608
+ # 'yes/no': OutputParsers.yes_no,
609
+ # 'numeric': OutputParsers.numeric_score,
610
+ # }
611
+ # if output_parser not in parser_map:
612
+ # raise ValueError(f"Unknown parser '{output_parser}'")
613
+ # self.output_parser = parser_map[output_parser]
614
+ # parser_name = output_parser
615
+ # else:
616
+ # self.output_parser = output_parser
617
+
618
+ # if use_fully_custom_prompt:
619
+ # self.template = None
620
+ # elif custom_template:
621
+ # self.template = custom_template
622
+ # elif parser_name and parser_name in self.PARSER_INSTRUCTIONS:
623
+ # self.template = self.BASE_TEMPLATE.format(
624
+ # instruction=self.PARSER_INSTRUCTIONS[parser_name],
625
+ # notes_section="{notes_section}",
626
+ # input_block="{input_block}",
627
+ # model_output="{model_output}",
628
+ # ground_truth="{ground_truth}",
629
+ # )
630
+ # else:
631
+ # self.template = self.BASE_TEMPLATE.format(
632
+ # instruction=self.PARSER_INSTRUCTIONS['right/wrong'],
633
+ # notes_section="{notes_section}",
634
+ # input_block="{input_block}",
635
+ # model_output="{model_output}",
636
+ # ground_truth="{ground_truth}",
637
+ # )
229
638
 
230
- Args:
231
- directory: Path to cache directory
232
- size_limit: Optional size limit in bytes. If None, uses 2^62 bytes (~4.6 exabytes)
233
- """
234
- import diskcache as dc
639
+ # # optional local cache setup
640
+ # # Enable cache if use_cache=True OR if litellm_cache_dir is explicitly provided (backward compatible)
641
+ # self.cache_enabled = use_cache or (litellm_cache_dir is not None)
642
+ # if self.cache_enabled:
643
+ # # Only set up cache if it hasn't been set up already
644
+ # if litellm.cache is None:
645
+ # # Set default cache directory if not specified
646
+ # if litellm_cache_dir is None:
647
+ # litellm_cache_dir = ".litellm_cache"
235
648
 
236
- # Set to very large cap so culling never triggers (effectively unlimited)
237
- cap = size_limit if size_limit is not None else (1 << 62)
238
- self._dc = dc.Cache(directory, size_limit=cap)
649
+ # # Convert GB to bytes if specified, otherwise unlimited
650
+ # size_limit_bytes = None if cache_size_gb is None else int(cache_size_gb * 1024 * 1024 * 1024)
651
+ # cache_backend = UnlimitedDiskCache(litellm_cache_dir, size_limit=size_limit_bytes)
652
+ # litellm.cache = Cache(disk_cache_dir=litellm_cache_dir)
653
+ # litellm.cache.cache = cache_backend
239
654
 
240
- # Sync API used by LiteLLM
241
- def get_cache(self, key, **kwargs):
242
- """Get value from cache by key."""
243
- return self._dc.get(key)
244
655
 
245
- def set_cache(self, key, value, ttl=None, **kwargs):
246
- """Set value in cache with optional TTL."""
247
- expire = None if ttl is None else float(ttl)
248
- self._dc.set(key, value, expire=expire)
249
656
 
250
- # Async API used by LiteLLM
251
- async def async_get_cache(self, key, **kwargs):
252
- """Async get value from cache by key."""
253
- return self.get_cache(key, **kwargs)
254
657
 
255
- async def async_set_cache(self, key, value, ttl=None, **kwargs):
256
- """Async set value in cache with optional TTL."""
257
- return self.set_cache(key, value, ttl=ttl, **kwargs)
258
658
 
259
- async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
260
- """
261
- Async batch set multiple cache entries.
262
659
 
263
- Args:
264
- cache_list: List of (key, value) tuples
265
- ttl: Optional time-to-live in seconds
266
- """
267
- for k, v in cache_list:
268
- self.set_cache(k, v, ttl=ttl)
269
660
 
270
- async def batch_cache_write(self, key, value, ttl=None, **kwargs):
271
- """Async batch write (single entry)."""
272
- self.set_cache(key, value, ttl=ttl)
273
661
 
274
- async def ping(self):
275
- """Async ping check."""
276
- return True
277
662
 
278
- async def delete_cache_keys(self, keys):
279
- """
280
- Async delete multiple cache keys.
663
+ # def _build_prompt(self, input: Any, model_output: Any, ground_truth: Any) -> str:
664
+ # notes_section = f"notes:\n{self.notes}\n" if self.notes else ""
665
+ # input_text = str(input) if input not in (None, "") else "[omitted input for brevity]"
666
+ # return self.template.format(
667
+ # notes_section=notes_section,
668
+ # input_block=input_text,
669
+ # model_output=str(model_output),
670
+ # ground_truth=str(ground_truth),
671
+ # )
281
672
 
282
- Args:
283
- keys: List of keys to delete
284
- """
285
- for k in keys:
286
- try:
287
- del self._dc[k]
288
- except KeyError:
289
- pass
290
- return True
673
+ # @staticmethod
674
+ # def _last6_right_wrong(s: str):
675
+ # if not s:
676
+ # return None
677
+ # tail = s.strip()[-6:].lower()
678
+ # if "right" in tail:
679
+ # return True
680
+ # if "wrong" in tail:
681
+ # return False
682
+ # return None
291
683
 
292
- async def disconnect(self):
293
- """Async disconnect and close cache."""
294
- self._dc.close()
684
+ # def _resolve_per_model(self, model: str) -> Tuple[Optional[str], Dict[str, str], float]:
685
+ # provider = model.split("/", 1)[0] if "/" in model else model
686
+
687
+ # api_base: Optional[str] = None
688
+ # headers: Dict[str, str] = dict(self.base_headers)
689
+ # temperature: float = self.default_temperature
690
+
691
+ # # provider-level defaults
692
+ # pc = self.config.get(provider, {})
693
+ # if pc.get("api_base") is not None:
694
+ # api_base = pc["api_base"]
695
+ # headers.update(pc.get("headers", {}))
696
+ # if "temperature" in pc:
697
+ # temperature = float(pc["temperature"])
698
+
699
+ # # model-level overrides
700
+ # mc = self.config.get(model, {})
701
+ # if mc.get("api_base") is not None:
702
+ # api_base = mc["api_base"]
703
+ # headers.update(mc.get("headers", {}))
704
+ # if "temperature" in mc:
705
+ # temperature = float(mc["temperature"])
706
+
707
+ # # wandb defaults
708
+ # if provider == "wandb":
709
+ # if api_base is None:
710
+ # api_base = "https://api.inference.wandb.ai/v1"
711
+ # if "OpenAI-Project" not in headers:
712
+ # headers["OpenAI-Project"] = self.wandb_project or "wandb_fc/quickstart_playground"
713
+
714
+ # return api_base, headers, temperature
715
+
716
+ # def _attempt_completion(
717
+ # self,
718
+ # model: str,
719
+ # api_base: Optional[str],
720
+ # headers: Dict[str, str],
721
+ # prompt: str,
722
+ # temperature: float,
723
+ # max_tokens: int,
724
+ # ) -> str:
725
+ # attempts = self.num_retries + 1
726
+ # last_err = None
727
+ # for i in range(attempts):
728
+ # try:
729
+ # # resp = completion(
730
+ # # model=model,
731
+ # # api_base=api_base, # None uses provider default
732
+ # # messages=[{"role": "user", "content": prompt}],
733
+ # # temperature=temperature,
734
+ # # max_tokens=max_tokens,
735
+ # # extra_headers=headers,
736
+ # # )
737
+
738
+ # resp = completion(
739
+ # model=model,
740
+ # api_base=api_base,
741
+ # messages=[{"role": "user", "content": prompt}],
742
+ # temperature=temperature,
743
+ # max_tokens=max_tokens,
744
+ # extra_headers=headers,
745
+ # caching=self.cache_enabled
746
+ # )
747
+ # return (resp.choices[0].message.content or "").strip()
748
+ # except Exception as e:
749
+ # last_err = e
750
+ # if i == attempts - 1:
751
+ # break
752
+ # sleep_s = min(self.backoff_max, self.backoff_base * (2 ** i))
753
+ # jitter = sleep_s * (0.1 * (2 * random.random() - 1.0)) # ±10%
754
+ # if self.verbose:
755
+ # print(f"[retry {i+1}/{attempts-1}] {model} error: {e} — sleeping {max(0.0, sleep_s + jitter):.2f}s", flush=True)
756
+ # time.sleep(max(0.0, sleep_s + jitter))
757
+ # raise last_err # fail after retries
758
+
759
+ # def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any):
760
+ # api_base, headers, temperature = self._resolve_per_model(model)
761
+
762
+ # content = self._attempt_completion(
763
+ # model=model,
764
+ # api_base=api_base,
765
+ # headers=headers,
766
+ # prompt=prompt,
767
+ # temperature=temperature,
768
+ # max_tokens=max_tokens,
769
+ # )
295
770
 
296
- def get_stats(self):
297
- """
298
- Get cache statistics.
771
+ # # Use the instance parser
772
+ # parsed = self.output_parser(content)
773
+
774
+ # # If parser returns None and fallback is enabled, do string comparison
775
+ # if parsed is None and self.fallback_comparison:
776
+ # return str(model_output).strip() == str(ground_truth).strip()
777
+
778
+ # return parsed
779
+
780
+ # def judge(
781
+ # self,
782
+ # input: Any = None,
783
+ # model_output: Any = None,
784
+ # ground_truth: Any = None,
785
+ # prompt: Optional[str] = None,
786
+ # max_tokens: int = 10000,
787
+ # ):
788
+ # # Validation for fully_custom_prompt mode
789
+ # if self.use_fully_custom_prompt:
790
+ # if prompt is None:
791
+ # raise ValueError(
792
+ # "When use_fully_custom_prompt=True, you must pass prompt to judge()."
793
+ # )
794
+ # if input is not None or model_output is not None or ground_truth is not None:
795
+ # raise ValueError(
796
+ # "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth to judge(). "
797
+ # "Only pass the complete prompt."
798
+ # )
799
+ # elif prompt is not None:
800
+ # raise ValueError(
801
+ # "prompt parameter can only be used when use_fully_custom_prompt=True. "
802
+ # "Otherwise, use input/model_output/ground_truth."
803
+ # )
804
+ # else:
805
+ # prompt = self._build_prompt(input, model_output, ground_truth)
806
+
807
+ # votes = []
808
+
809
+ # # Vote with litellm models
810
+ # for m in self.models:
811
+ # res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
812
+ # if self.verbose:
813
+ # print(f"Model {m} voted: {res}", flush=True)
814
+ # votes.append({"model": m, "correct": res})
815
+
816
+ # # Vote with custom generation functions
817
+ # for idx, custom_fn in enumerate(self.custom_generation_fns):
818
+ # try:
819
+ # content = custom_fn(prompt)
820
+ # parsed = self.output_parser(content)
821
+
822
+ # # If parser returns None and fallback is enabled, do string comparison
823
+ # if parsed is None and self.fallback_comparison:
824
+ # res = str(model_output).strip() == str(ground_truth).strip()
825
+ # else:
826
+ # res = parsed
827
+
828
+ # if self.verbose:
829
+ # print(f"Custom function {idx} voted: {res}", flush=True)
830
+ # votes.append({"model": f"custom_fn_{idx}", "correct": res})
831
+ # except Exception as e:
832
+ # if self.verbose:
833
+ # print(f"Custom function {idx} failed: {e}", flush=True)
834
+ # # If custom function fails and fallback is enabled, do string comparison
835
+ # if self.fallback_comparison:
836
+ # res = str(model_output).strip() == str(ground_truth).strip()
837
+ # votes.append({"model": f"custom_fn_{idx}", "correct": res})
838
+ # else:
839
+ # raise
840
+
841
+ # if self.mode == "single":
842
+ # final = votes[0]["correct"]
843
+ # elif self.mode == "majority":
844
+ # true_votes = sum(v["correct"] for v in votes)
845
+ # false_votes = len(votes) - true_votes
846
+ # final = True if true_votes >= false_votes else False
847
+ # elif self.mode == "all":
848
+ # final = all(v["correct"] for v in votes)
849
+ # else:
850
+ # raise ValueError("mode must be 'majority', 'single', or 'all'")
851
+
852
+ # return {"correct": final, "mode": self.mode, "votes": votes}
853
+
854
+
855
+
856
+
857
+
858
+ import os
859
+ import time
860
+ import random
861
+ import re
862
+ import json
863
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
864
+ from enum import Enum
865
+ import litellm
866
+ from litellm import completion
867
+ from litellm.caching.caching import Cache
299
868
 
300
- Returns:
301
- dict with size_limit, current_size, item_count, and percent_full
302
- """
303
- size_limit = self._dc.size_limit
304
- volume = self._dc.volume() # Current size in bytes
305
- count = len(self._dc) # Number of items
306
-
307
- return {
308
- "size_limit": size_limit,
309
- "current_size": volume,
310
- "item_count": count,
311
- "percent_full": (volume / size_limit) * 100 if size_limit > 0 else 0.0,
312
- }
313
-
314
- def print_stats(self):
315
- """Print human-readable cache statistics."""
316
- stats = self.get_stats()
317
-
318
- def human_size(bytes_val):
319
- """Convert bytes to human readable format."""
320
- for unit in ["B", "KB", "MB", "GB", "TB", "PB", "EB"]:
321
- if bytes_val < 1024.0:
322
- return f"{bytes_val:.2f} {unit}"
323
- bytes_val /= 1024.0
324
- return f"{bytes_val:.2f} EB"
325
-
326
- print("=" * 60)
327
- print("CACHE STATISTICS")
328
- print("=" * 60)
329
- print(f" Size limit: {human_size(stats['size_limit'])} ({stats['size_limit']:,} bytes)")
330
- print(f" Current size: {human_size(stats['current_size'])} ({stats['current_size']:,} bytes)")
331
- print(f" Items cached: {stats['item_count']}")
332
- print(f" % full: {stats['percent_full']:.6f}%")
333
- print("=" * 60)
869
+
870
+ __all__ = ["LLMAsAJudge", "OutputParsers"]
871
+
872
+
873
+ class ReturnType(Enum):
874
+ """Enum for categorizing parser return types."""
875
+ BOOLEAN = "boolean"
876
+ SCALAR = "scalar"
877
+ MAP = "map"
878
+
879
+
880
+ class AggregationMode(Enum):
881
+ """Enum for aggregation modes across multiple voters."""
882
+ # Boolean modes
883
+ MAJORITY = "majority"
884
+ SINGLE = "single"
885
+ ALL = "all"
886
+ # Scalar modes
887
+ AVERAGE = "average"
888
+ MIN = "min"
889
+ MAX = "max"
890
+ MEDIAN = "median"
891
+
892
+
893
+ # Valid aggregation modes per return type
894
+ VALID_MODES = {
895
+ ReturnType.BOOLEAN: {AggregationMode.MAJORITY, AggregationMode.SINGLE, AggregationMode.ALL},
896
+ ReturnType.SCALAR: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
897
+ ReturnType.MAP: {AggregationMode.AVERAGE, AggregationMode.MIN, AggregationMode.MAX, AggregationMode.MEDIAN, AggregationMode.SINGLE},
898
+ }
899
+
900
+ # Default aggregation modes per return type
901
+ DEFAULT_MODES = {
902
+ ReturnType.BOOLEAN: AggregationMode.MAJORITY,
903
+ ReturnType.SCALAR: AggregationMode.AVERAGE,
904
+ ReturnType.MAP: AggregationMode.AVERAGE,
905
+ }
906
+
907
+ # String to enum mapping (for backward compat)
908
+ MODE_STR_MAP = {
909
+ 'majority': AggregationMode.MAJORITY,
910
+ 'single': AggregationMode.SINGLE,
911
+ 'all': AggregationMode.ALL,
912
+ 'average': AggregationMode.AVERAGE,
913
+ 'min': AggregationMode.MIN,
914
+ 'max': AggregationMode.MAX,
915
+ 'median': AggregationMode.MEDIAN,
916
+ }
334
917
 
335
918
 
336
919
  class OutputParsers:
337
920
  """Stock output parsers for common judge output formats."""
338
921
 
922
+ # ========== BOOLEAN PARSERS ==========
923
+
339
924
  @staticmethod
340
925
  def right_wrong(s: str) -> Optional[bool]:
341
926
  """Parse 'right' or 'wrong' from the last 6 characters."""
@@ -372,6 +957,8 @@ class OutputParsers:
372
957
  return False
373
958
  return None
374
959
 
960
+ # ========== SCALAR PARSERS ==========
961
+
375
962
  @staticmethod
376
963
  def numeric_score(s: str) -> Optional[float]:
377
964
  """Extract first numeric value from the response."""
@@ -382,20 +969,188 @@ class OutputParsers:
382
969
  return float(match.group())
383
970
  return None
384
971
 
972
+ @staticmethod
973
+ def score_0_to_10(s: str) -> Optional[float]:
974
+ """Extract a score from 0-10, clamping to valid range."""
975
+ if not s:
976
+ return None
977
+ match = re.search(r'[-+]?\d*\.?\d+', s.strip())
978
+ if match:
979
+ return max(0.0, min(10.0, float(match.group())))
980
+ return None
981
+
982
+ @staticmethod
983
+ def score_0_to_100(s: str) -> Optional[float]:
984
+ """Extract a score from 0-100, clamping to valid range."""
985
+ if not s:
986
+ return None
987
+ match = re.search(r'[-+]?\d*\.?\d+', s.strip())
988
+ if match:
989
+ return max(0.0, min(100.0, float(match.group())))
990
+ return None
991
+
992
+ @staticmethod
993
+ def score_1_to_5(s: str) -> Optional[float]:
994
+ """Extract a score from 1-5, clamping to valid range."""
995
+ if not s:
996
+ return None
997
+ match = re.search(r'[-+]?\d*\.?\d+', s.strip())
998
+ if match:
999
+ return max(1.0, min(5.0, float(match.group())))
1000
+ return None
1001
+
1002
+ # ========== MAP PARSERS ==========
1003
+
385
1004
  @staticmethod
386
1005
  def json_extract(key: str) -> Callable[[str], Any]:
387
1006
  """Create a parser that extracts a specific key from JSON output."""
388
- import json
389
1007
  def parser(s: str) -> Any:
390
1008
  if not s:
391
1009
  return None
392
1010
  try:
393
- data = json.loads(s.strip())
1011
+ s = s.strip()
1012
+ if "```json" in s:
1013
+ start = s.find("```json") + 7
1014
+ end = s.find("```", start)
1015
+ s = s[start:end].strip()
1016
+ elif "```" in s:
1017
+ start = s.find("```") + 3
1018
+ end = s.find("```", start)
1019
+ s = s[start:end].strip()
1020
+ data = json.loads(s)
394
1021
  return data.get(key)
395
1022
  except (json.JSONDecodeError, AttributeError):
396
1023
  return None
397
1024
  return parser
398
1025
 
1026
+ @staticmethod
1027
+ def json_map(keys: Optional[List[str]] = None) -> Callable[[str], Optional[Dict[str, float]]]:
1028
+ """
1029
+ Create a parser that extracts multiple keys from JSON output as a map.
1030
+ Returns Dict[str, float] or None.
1031
+
1032
+ More robust version that handles:
1033
+ - Code blocks with ```json or ```
1034
+ - Extra text before/after JSON
1035
+ - Various JSON formatting issues
1036
+ """
1037
+ def parser(s: str) -> Optional[Dict[str, float]]:
1038
+ if not s:
1039
+ return None
1040
+ try:
1041
+ s = s.strip()
1042
+
1043
+ # Handle markdown code blocks
1044
+ if "```json" in s.lower():
1045
+ start = s.lower().find("```json") + 7
1046
+ end = s.find("```", start)
1047
+ if end > start:
1048
+ s = s[start:end].strip()
1049
+ elif "```" in s:
1050
+ start = s.find("```") + 3
1051
+ end = s.find("```", start)
1052
+ if end > start:
1053
+ s = s[start:end].strip()
1054
+
1055
+ # Try to find JSON object if there's extra text
1056
+ # Look for first { and last }
1057
+ if '{' in s and '}' in s:
1058
+ start_brace = s.find('{')
1059
+ end_brace = s.rfind('}')
1060
+ if start_brace < end_brace:
1061
+ s = s[start_brace:end_brace + 1]
1062
+
1063
+ data = json.loads(s)
1064
+ if not isinstance(data, dict):
1065
+ return None
1066
+
1067
+ result = {}
1068
+ target_keys = keys if keys else list(data.keys())
1069
+
1070
+ for key in target_keys:
1071
+ if key in data:
1072
+ val = data[key]
1073
+ if isinstance(val, (int, float)):
1074
+ result[key] = float(val)
1075
+ elif isinstance(val, str):
1076
+ try:
1077
+ result[key] = float(val)
1078
+ except ValueError:
1079
+ pass
1080
+
1081
+ return result if result else None
1082
+ except (json.JSONDecodeError, AttributeError, ValueError):
1083
+ return None
1084
+ return parser
1085
+
1086
+ @staticmethod
1087
+ def multi_score_pattern(pattern: str = r'(\w+):\s*([\d.]+)') -> Callable[[str], Optional[Dict[str, float]]]:
1088
+ """
1089
+ Create a parser that extracts key-value pairs using regex pattern.
1090
+ Default pattern matches "key: value" format.
1091
+ """
1092
+ def parser(s: str) -> Optional[Dict[str, float]]:
1093
+ if not s:
1094
+ return None
1095
+ matches = re.findall(pattern, s)
1096
+ if not matches:
1097
+ return None
1098
+ result = {}
1099
+ for key, val in matches:
1100
+ try:
1101
+ result[key.lower()] = float(val)
1102
+ except ValueError:
1103
+ pass
1104
+ return result if result else None
1105
+ return parser
1106
+
1107
+
1108
+ def _infer_return_type(value: Any) -> Optional[ReturnType]:
1109
+ """Infer the return type from a parsed value."""
1110
+ if value is None:
1111
+ return None
1112
+ if isinstance(value, bool):
1113
+ return ReturnType.BOOLEAN
1114
+ if isinstance(value, (int, float)):
1115
+ return ReturnType.SCALAR
1116
+ if isinstance(value, dict) and all(isinstance(v, (int, float)) for v in value.values()):
1117
+ return ReturnType.MAP
1118
+ return None
1119
+
1120
+
1121
+ def _validate_consistent_types(votes: List[Dict[str, Any]]) -> ReturnType:
1122
+ """Validate that all votes have consistent return types."""
1123
+ return_types = set()
1124
+ for vote in votes:
1125
+ result = vote.get("result")
1126
+ if result is not None:
1127
+ rt = _infer_return_type(result)
1128
+ if rt is not None:
1129
+ return_types.add(rt)
1130
+
1131
+ if len(return_types) == 0:
1132
+ raise ValueError("All parsers returned None - cannot determine return type")
1133
+ if len(return_types) > 1:
1134
+ raise ValueError(
1135
+ f"Mixed return types detected: {[rt.value for rt in return_types]}. "
1136
+ "All judges must return the same type (boolean, scalar, or map)."
1137
+ )
1138
+ return return_types.pop()
1139
+
1140
+
1141
+ def _validate_consistent_map_keys(votes: List[Dict[str, Any]]) -> set:
1142
+ """Validate that all map votes have the same keys."""
1143
+ all_keys = None
1144
+ for vote in votes:
1145
+ result = vote.get("result")
1146
+ if isinstance(result, dict):
1147
+ keys = set(result.keys())
1148
+ if all_keys is None:
1149
+ all_keys = keys
1150
+ elif keys != all_keys:
1151
+ raise ValueError(f"Inconsistent map keys across voters. Expected {all_keys}, got {keys}")
1152
+ return all_keys or set()
1153
+
399
1154
 
400
1155
  class LLMAsAJudge:
401
1156
  BASE_TEMPLATE = """\
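Not part of the diff: a small sanity-check sketch of how the parsers added in the hunk above behave, assuming the 0.1.14 wheel is installed as llmasajudge; the input strings and scores are illustrative only.

from llmasajudge import OutputParsers

# Scalar parsers extract the first number and clamp it to the stated range.
assert OutputParsers.numeric_score("I'd rate this 7.5 out of 10") == 7.5
assert OutputParsers.score_0_to_10("Score: 12") == 10.0  # clamped to the 0-10 range

# json_map() tolerates markdown fences and surrounding text, and coerces
# numeric strings to floats.
parse_map = OutputParsers.json_map()
assert parse_map('```json\n{"accuracy": 8, "relevance": "9"}\n```') == {
    "accuracy": 8.0,
    "relevance": 9.0,
}

# multi_score_pattern() extracts "key: value" pairs with a regex instead of JSON;
# keys are lower-cased.
assert OutputParsers.multi_score_pattern()("Accuracy: 8\nhelpfulness: 7.5") == {
    "accuracy": 8.0,
    "helpfulness": 7.5,
}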
@@ -439,123 +1194,13 @@ Return a single numeric score from 0-10 indicating how well the model output mat
439
1194
  - 1-3 = poor match
440
1195
  - 0 = completely wrong
441
1196
  Output only the number. No explanation. No extra text.""",
1197
+ 'multi_score': """\
1198
+ Evaluate the model output against the ground truth on multiple dimensions.
1199
+ Return your scores in JSON format with numeric values (0-10 scale).
1200
+ Example: {"accuracy": 8, "helpfulness": 7, "relevance": 9}
1201
+ Output only valid JSON. No explanation. No extra text.""",
442
1202
  }
443
1203
 
444
-
445
-
446
-
447
- # def __init__(
448
- # self,
449
- # models: Optional[List[str]] = None,
450
- # config: Optional[Dict[str, Dict[str, Any]]] = None, # one dict for providers and models
451
- # base_headers: Optional[Dict[str, str]] = None,
452
- # wandb_project: Optional[str] = None,
453
- # custom_template: Optional[str] = None,
454
- # use_fully_custom_prompt: bool = False,
455
- # notes: Optional[str] = None,
456
- # output_parser: Optional[str] = 'right/wrong',
457
- # fallback_comparison: bool = True,
458
- # default_temperature: float = 0.0,
459
- # verbose: bool = False,
460
- # num_retries: int = 2, # per-call retries before giving up on that model
461
- # backoff_base: float = 0.5, # seconds
462
- # backoff_max: float = 4.0, # seconds
463
- # custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
464
- # mode: str = "majority", # "single", "majority", "all"
465
- # ):
466
- # """
467
- # config keys can be a provider name ("wandb", "openai", "anthropic")
468
- # or a full model name ("openai/gpt-4o-mini", "wandb/deepseek-ai/DeepSeek-V3.1").
469
-
470
- # Values can include:
471
- # api_base: Optional[str]
472
- # headers: Dict[str, str]
473
- # temperature: float
474
-
475
- # Precedence:
476
- # base_headers < provider config < model config
477
-
478
- # Args:
479
- # models: List of litellm model strings (e.g., ["openai/gpt-4", "anthropic/claude-3"])
480
- # custom_template: Template with placeholders for input/output/ground_truth
481
- # use_fully_custom_prompt: If True, pass complete prompt to judge(prompt=...).
482
- # When True, input/output/ground_truth must NOT be passed to judge()
483
- # output_parser: Parser name ('right/wrong', 'yes/no', 'pass/fail', 'numeric')
484
- # or custom function with signature (str) -> Any
485
- # fallback_comparison: If True and parser returns None, falls back to string comparison
486
- # custom_generation_fns: List of custom inference functions with signature fn(prompt: str) -> str
487
- # These will be used in addition to litellm models for voting.
488
- # mode: Voting mode - "majority" (default), "single" (first judge only), or "all" (unanimous)
489
- # """
490
- # self.models = models or []
491
- # self.custom_generation_fns = custom_generation_fns or []
492
-
493
- # # Validate that at least one judge is provided
494
- # if not self.models and not self.custom_generation_fns:
495
- # raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
496
-
497
- # # Validate mode
498
- # if mode not in ("majority", "single", "all"):
499
- # raise ValueError("mode must be 'majority', 'single', or 'all'")
500
-
501
- # self.config = config or {}
502
- # self.base_headers = dict(base_headers or {})
503
- # self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
504
- # self.notes = notes or ""
505
- # self.use_fully_custom_prompt = use_fully_custom_prompt
506
- # self.mode = mode
507
-
508
- # # Resolve output parser
509
- # parser_name = None
510
- # if isinstance(output_parser, str):
511
- # parser_map = {
512
- # 'right/wrong': OutputParsers.right_wrong,
513
- # 'pass/fail': OutputParsers.pass_fail,
514
- # 'yes/no': OutputParsers.yes_no,
515
- # 'numeric': OutputParsers.numeric_score,
516
- # }
517
- # if output_parser not in parser_map:
518
- # raise ValueError(f"Unknown parser '{output_parser}'. Available: {list(parser_map.keys())}")
519
- # self.output_parser = parser_map[output_parser]
520
- # parser_name = output_parser
521
- # else:
522
- # self.output_parser = output_parser
523
-
524
- # # Set template based on mode
525
- # if use_fully_custom_prompt:
526
- # self.template = None # No template in fully custom mode
527
- # elif custom_template:
528
- # self.template = custom_template
529
- # elif parser_name and parser_name in self.PARSER_INSTRUCTIONS:
530
- # self.template = self.BASE_TEMPLATE.format(
531
- # instruction=self.PARSER_INSTRUCTIONS[parser_name],
532
- # notes_section="{notes_section}",
533
- # input_block="{input_block}",
534
- # model_output="{model_output}",
535
- # ground_truth="{ground_truth}",
536
- # )
537
- # else:
538
- # # Default to right/wrong for custom parsers
539
- # self.template = self.BASE_TEMPLATE.format(
540
- # instruction=self.PARSER_INSTRUCTIONS['right/wrong'],
541
- # notes_section="{notes_section}",
542
- # input_block="{input_block}",
543
- # model_output="{model_output}",
544
- # ground_truth="{ground_truth}",
545
- # )
546
-
547
- # self.fallback_comparison = fallback_comparison
548
- # self.default_temperature = float(default_temperature)
549
- # self.verbose = verbose
550
- # self.num_retries = int(num_retries)
551
- # self.backoff_base = float(backoff_base)
552
- # self.backoff_max = float(backoff_max)
553
-
554
-
555
-
556
-
557
-
558
-
559
1204
  def __init__(
560
1205
  self,
561
1206
  models: Optional[List[str]] = None,
@@ -565,7 +1210,7 @@ Output only the number. No explanation. No extra text.""",
565
1210
  custom_template: Optional[str] = None,
566
1211
  use_fully_custom_prompt: bool = False,
567
1212
  notes: Optional[str] = None,
568
- output_parser: Optional[str] = 'right/wrong',
1213
+ output_parser: Optional[Union[str, Callable]] = 'right/wrong',
569
1214
  fallback_comparison: bool = True,
570
1215
  default_temperature: float = 0.0,
571
1216
  verbose: bool = False,
@@ -573,9 +1218,9 @@ Output only the number. No explanation. No extra text.""",
573
1218
  backoff_base: float = 0.5,
574
1219
  backoff_max: float = 4.0,
575
1220
  custom_generation_fns: Optional[List[Callable[[str], str]]] = None,
576
- mode: str = "majority",
1221
+ mode: Optional[str] = None,
577
1222
  litellm_cache_dir: Optional[str] = None,
578
- cache_size_gb: Optional[float] = None,
1223
+ return_type: Optional[str] = None,
579
1224
  ):
580
1225
  self.models = models or []
581
1226
  self.custom_generation_fns = custom_generation_fns or []
@@ -583,15 +1228,11 @@ Output only the number. No explanation. No extra text.""",
583
1228
  if not self.models and not self.custom_generation_fns:
584
1229
  raise ValueError("Must provide at least one of: models (litellm) or custom_generation_fns")
585
1230
 
586
- if mode not in ("majority", "single", "all"):
587
- raise ValueError("mode must be 'majority', 'single', or 'all'")
588
-
589
1231
  self.config = config or {}
590
1232
  self.base_headers = dict(base_headers or {})
591
1233
  self.wandb_project = wandb_project or os.getenv("WANDB_PROJECT")
592
1234
  self.notes = notes or ""
593
1235
  self.use_fully_custom_prompt = use_fully_custom_prompt
594
- self.mode = mode
595
1236
  self.fallback_comparison = fallback_comparison
596
1237
  self.default_temperature = float(default_temperature)
597
1238
  self.verbose = verbose
@@ -599,6 +1240,7 @@ Output only the number. No explanation. No extra text.""",
599
1240
  self.backoff_base = float(backoff_base)
600
1241
  self.backoff_max = float(backoff_max)
601
1242
 
1243
+ # Resolve output parser
602
1244
  parser_name = None
603
1245
  if isinstance(output_parser, str):
604
1246
  parser_map = {
@@ -606,6 +1248,7 @@ Output only the number. No explanation. No extra text.""",
606
1248
  'pass/fail': OutputParsers.pass_fail,
607
1249
  'yes/no': OutputParsers.yes_no,
608
1250
  'numeric': OutputParsers.numeric_score,
1251
+ 'multi_score': OutputParsers.json_map(),
609
1252
  }
610
1253
  if output_parser not in parser_map:
611
1254
  raise ValueError(f"Unknown parser '{output_parser}'")
@@ -614,6 +1257,42 @@ Output only the number. No explanation. No extra text.""",
614
1257
  else:
615
1258
  self.output_parser = output_parser
616
1259
 
1260
+ # Determine expected return type
1261
+ self._explicit_return_type: Optional[ReturnType] = None
1262
+ if return_type is not None:
1263
+ self._explicit_return_type = ReturnType(return_type)
1264
+ elif parser_name:
1265
+ if parser_name in ('right/wrong', 'pass/fail', 'yes/no'):
1266
+ self._explicit_return_type = ReturnType.BOOLEAN
1267
+ elif parser_name == 'numeric':
1268
+ self._explicit_return_type = ReturnType.SCALAR
1269
+ elif parser_name == 'multi_score':
1270
+ self._explicit_return_type = ReturnType.MAP
1271
+
1272
+ # Resolve aggregation mode
1273
+ if mode is not None:
1274
+ if mode not in MODE_STR_MAP:
1275
+ raise ValueError(f"Unknown mode '{mode}'. Available: {list(MODE_STR_MAP.keys())}")
1276
+ self._mode = MODE_STR_MAP[mode]
1277
+ else:
1278
+ if self._explicit_return_type:
1279
+ self._mode = DEFAULT_MODES[self._explicit_return_type]
1280
+ else:
1281
+ self._mode = AggregationMode.MAJORITY
1282
+
1283
+ # Validate mode against return type if known
1284
+ if self._explicit_return_type:
1285
+ valid_modes = VALID_MODES[self._explicit_return_type]
1286
+ if self._mode not in valid_modes:
1287
+ raise ValueError(
1288
+ f"Mode '{self._mode.value}' not valid for return type '{self._explicit_return_type.value}'. "
1289
+ f"Valid: {[m.value for m in valid_modes]}"
1290
+ )
1291
+
1292
+ # For backward compat, expose mode as string
1293
+ self.mode = self._mode.value
1294
+
1295
+ # Set template
617
1296
  if use_fully_custom_prompt:
618
1297
  self.template = None
619
1298
  elif custom_template:
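Not part of the diff: a sketch of the mode resolution added in the hunk above, assuming the 0.1.14 wheel is installed; the lambda stands in for any custom generation function and exists only so the constructor's "at least one judge" check passes.

from llmasajudge import LLMAsAJudge

dummy = [lambda prompt: "right"]

# Without an explicit mode, the default follows the parser's return type.
assert LLMAsAJudge(custom_generation_fns=dummy).mode == "majority"  # boolean parser
assert LLMAsAJudge(custom_generation_fns=dummy, output_parser="numeric").mode == "average"
assert LLMAsAJudge(custom_generation_fns=dummy, output_parser="multi_score").mode == "average"

# Incompatible combinations are rejected up front.
try:
    LLMAsAJudge(custom_generation_fns=dummy, output_parser="numeric", mode="all")
except ValueError as e:
    print(e)  # "Mode 'all' not valid for return type 'scalar'. ..."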
@@ -635,43 +1314,23 @@ Output only the number. No explanation. No extra text.""",
635
1314
  ground_truth="{ground_truth}",
636
1315
  )
637
1316
 
638
- # optional local cache setup
1317
+ # Optional local cache setup
639
1318
  self.cache_enabled = litellm_cache_dir is not None
640
1319
  if self.cache_enabled:
641
- # Convert GB to bytes if specified, otherwise unlimited
642
- size_limit_bytes = None if cache_size_gb is None else int(cache_size_gb * 1024 * 1024 * 1024)
643
- cache_backend = UnlimitedDiskCache(litellm_cache_dir, size_limit=size_limit_bytes)
644
- litellm.cache = Cache(disk_cache_dir=litellm_cache_dir)
645
- litellm.cache.cache = cache_backend
646
-
647
-
648
-
649
-
650
-
651
-
652
-
653
-
1320
+ litellm.cache = Cache(type="disk", disk_cache_dir=litellm_cache_dir)
654
1321
 
655
1322
  def _build_prompt(self, input: Any, model_output: Any, ground_truth: Any) -> str:
656
1323
  notes_section = f"notes:\n{self.notes}\n" if self.notes else ""
657
1324
  input_text = str(input) if input not in (None, "") else "[omitted input for brevity]"
658
- return self.template.format(
659
- notes_section=notes_section,
660
- input_block=input_text,
661
- model_output=str(model_output),
662
- ground_truth=str(ground_truth),
663
- )
664
1325
 
665
- @staticmethod
666
- def _last6_right_wrong(s: str):
667
- if not s:
668
- return None
669
- tail = s.strip()[-6:].lower()
670
- if "right" in tail:
671
- return True
672
- if "wrong" in tail:
673
- return False
674
- return None
1326
+ # Use string replacement instead of .format() to avoid issues with { } in template
1327
+ # (e.g., JSON examples in multi_score instructions)
1328
+ prompt = self.template
1329
+ prompt = prompt.replace("{notes_section}", notes_section)
1330
+ prompt = prompt.replace("{input_block}", input_text)
1331
+ prompt = prompt.replace("{model_output}", str(model_output))
1332
+ prompt = prompt.replace("{ground_truth}", str(ground_truth))
1333
+ return prompt
675
1334
 
676
1335
  def _resolve_per_model(self, model: str) -> Tuple[Optional[str], Dict[str, str], float]:
677
1336
  provider = model.split("/", 1)[0] if "/" in model else model
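Not part of the diff: a sketch of the simplified cache setup shown above (0.1.14 drops the UnlimitedDiskCache backend and the cache_size_gb knob in favor of LiteLLM's stock disk cache). The model string is a placeholder, and the snippet assumes litellm's optional diskcache dependency is installed; constructing the judge does not call any API.

from llmasajudge import LLMAsAJudge

judge = LLMAsAJudge(
    models=["openai/gpt-4o-mini"],       # placeholder model string
    litellm_cache_dir=".litellm_cache",  # any writable directory
)
# Passing litellm_cache_dir sets litellm.cache to a disk-backed Cache, and every
# completion() call is made with caching=True, so repeated identical judge()
# calls are served from the on-disk cache.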
@@ -680,7 +1339,6 @@ Output only the number. No explanation. No extra text.""",
680
1339
  headers: Dict[str, str] = dict(self.base_headers)
681
1340
  temperature: float = self.default_temperature
682
1341
 
683
- # provider-level defaults
684
1342
  pc = self.config.get(provider, {})
685
1343
  if pc.get("api_base") is not None:
686
1344
  api_base = pc["api_base"]
@@ -688,7 +1346,6 @@ Output only the number. No explanation. No extra text.""",
688
1346
  if "temperature" in pc:
689
1347
  temperature = float(pc["temperature"])
690
1348
 
691
- # model-level overrides
692
1349
  mc = self.config.get(model, {})
693
1350
  if mc.get("api_base") is not None:
694
1351
  api_base = mc["api_base"]
@@ -696,7 +1353,6 @@ Output only the number. No explanation. No extra text.""",
696
1353
  if "temperature" in mc:
697
1354
  temperature = float(mc["temperature"])
698
1355
 
699
- # wandb defaults
700
1356
  if provider == "wandb":
701
1357
  if api_base is None:
702
1358
  api_base = "https://api.inference.wandb.ai/v1"
@@ -718,15 +1374,6 @@ Output only the number. No explanation. No extra text.""",
718
1374
  last_err = None
719
1375
  for i in range(attempts):
720
1376
  try:
721
- # resp = completion(
722
- # model=model,
723
- # api_base=api_base, # None uses provider default
724
- # messages=[{"role": "user", "content": prompt}],
725
- # temperature=temperature,
726
- # max_tokens=max_tokens,
727
- # extra_headers=headers,
728
- # )
729
-
730
1377
  resp = completion(
731
1378
  model=model,
732
1379
  api_base=api_base,
@@ -735,20 +1382,20 @@ Output only the number. No explanation. No extra text.""",
735
1382
  max_tokens=max_tokens,
736
1383
  extra_headers=headers,
737
1384
  caching=self.cache_enabled
738
- )
1385
+ )
739
1386
  return (resp.choices[0].message.content or "").strip()
740
1387
  except Exception as e:
741
1388
  last_err = e
742
1389
  if i == attempts - 1:
743
1390
  break
744
1391
  sleep_s = min(self.backoff_max, self.backoff_base * (2 ** i))
745
- jitter = sleep_s * (0.1 * (2 * random.random() - 1.0)) # ±10%
1392
+ jitter = sleep_s * (0.1 * (2 * random.random() - 1.0))
746
1393
  if self.verbose:
747
1394
  print(f"[retry {i+1}/{attempts-1}] {model} error: {e} — sleeping {max(0.0, sleep_s + jitter):.2f}s", flush=True)
748
1395
  time.sleep(max(0.0, sleep_s + jitter))
749
- raise last_err # fail after retries
1396
+ raise last_err
750
1397
 
751
- def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any):
1398
+ def _ask_model(self, model: str, prompt: str, max_tokens: int, model_output: Any, ground_truth: Any) -> Any:
752
1399
  api_base, headers, temperature = self._resolve_per_model(model)
753
1400
 
754
1401
  content = self._attempt_completion(
@@ -760,15 +1407,80 @@ Output only the number. No explanation. No extra text.""",
760
1407
  max_tokens=max_tokens,
761
1408
  )
762
1409
 
763
- # Use the instance parser
764
1410
  parsed = self.output_parser(content)
765
1411
 
766
- # If parser returns None and fallback is enabled, do string comparison
767
1412
  if parsed is None and self.fallback_comparison:
768
1413
  return str(model_output).strip() == str(ground_truth).strip()
769
1414
 
770
1415
  return parsed
771
1416
 
1417
+ def _aggregate_boolean(self, votes: List[Dict[str, Any]]) -> bool:
1418
+ results = [v["result"] for v in votes if v["result"] is not None]
1419
+ if not results:
1420
+ raise ValueError("No valid votes to aggregate")
1421
+
1422
+ if self._mode == AggregationMode.SINGLE:
1423
+ return bool(results[0])
1424
+ elif self._mode == AggregationMode.MAJORITY:
1425
+ return sum(1 for r in results if r) >= len(results) / 2
1426
+ elif self._mode == AggregationMode.ALL:
1427
+ return all(results)
1428
+ else:
1429
+ raise ValueError(f"Invalid mode for boolean: {self._mode}")
1430
+
1431
+ def _aggregate_scalar(self, votes: List[Dict[str, Any]]) -> float:
1432
+ results = [float(v["result"]) for v in votes if v["result"] is not None]
1433
+ if not results:
1434
+ raise ValueError("No valid votes to aggregate")
1435
+
1436
+ if self._mode == AggregationMode.SINGLE:
1437
+ return results[0]
1438
+ elif self._mode == AggregationMode.AVERAGE:
1439
+ return sum(results) / len(results)
1440
+ elif self._mode == AggregationMode.MIN:
1441
+ return min(results)
1442
+ elif self._mode == AggregationMode.MAX:
1443
+ return max(results)
1444
+ elif self._mode == AggregationMode.MEDIAN:
1445
+ s = sorted(results)
1446
+ n = len(s)
1447
+ mid = n // 2
1448
+ return (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
1449
+ else:
1450
+ raise ValueError(f"Invalid mode for scalar: {self._mode}")
1451
+
1452
+ def _aggregate_map(self, votes: List[Dict[str, Any]]) -> Dict[str, float]:
1453
+ valid = [v["result"] for v in votes if v["result"] is not None and isinstance(v["result"], dict)]
1454
+ if not valid:
1455
+ raise ValueError("No valid map votes to aggregate")
1456
+
1457
+ keys = set()
1458
+ for v in valid:
1459
+ keys.update(v.keys())
1460
+
1461
+ if self._mode == AggregationMode.SINGLE:
1462
+ return valid[0]
1463
+
1464
+ result = {}
1465
+ for key in keys:
1466
+ values = [v[key] for v in valid if key in v]
1467
+ if not values:
1468
+ continue
1469
+
1470
+ if self._mode == AggregationMode.AVERAGE:
1471
+ result[key] = sum(values) / len(values)
1472
+ elif self._mode == AggregationMode.MIN:
1473
+ result[key] = min(values)
1474
+ elif self._mode == AggregationMode.MAX:
1475
+ result[key] = max(values)
1476
+ elif self._mode == AggregationMode.MEDIAN:
1477
+ s = sorted(values)
1478
+ n = len(s)
1479
+ mid = n // 2
1480
+ result[key] = (s[mid - 1] + s[mid]) / 2 if n % 2 == 0 else s[mid]
1481
+
1482
+ return result
1483
+
772
1484
  def judge(
773
1485
  self,
774
1486
  input: Any = None,
@@ -776,23 +1488,25 @@ Output only the number. No explanation. No extra text.""",
776
1488
  ground_truth: Any = None,
777
1489
  prompt: Optional[str] = None,
778
1490
  max_tokens: int = 10000,
779
- ):
780
- # Validation for fully_custom_prompt mode
1491
+ ) -> Dict[str, Any]:
1492
+ """
1493
+ Run the judge evaluation.
1494
+
1495
+ Returns dict with:
1496
+ - 'correct': bool or None (bool for boolean type, None for scalar/map)
1497
+ - 'scores': float or Dict[str, float] or None (for scalar/map, None for boolean)
1498
+ - 'mode': aggregation mode string
1499
+ - 'votes': list of individual judge votes
1500
+ """
781
1501
  if self.use_fully_custom_prompt:
782
1502
  if prompt is None:
783
- raise ValueError(
784
- "When use_fully_custom_prompt=True, you must pass prompt to judge()."
785
- )
1503
+ raise ValueError("When use_fully_custom_prompt=True, you must pass prompt to judge().")
786
1504
  if input is not None or model_output is not None or ground_truth is not None:
787
1505
  raise ValueError(
788
- "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth to judge(). "
789
- "Only pass the complete prompt."
1506
+ "When use_fully_custom_prompt=True, you cannot pass input, model_output, or ground_truth."
790
1507
  )
791
1508
  elif prompt is not None:
792
- raise ValueError(
793
- "prompt parameter can only be used when use_fully_custom_prompt=True. "
794
- "Otherwise, use input/model_output/ground_truth."
795
- )
1509
+ raise ValueError("prompt parameter can only be used when use_fully_custom_prompt=True.")
796
1510
  else:
797
1511
  prompt = self._build_prompt(input, model_output, ground_truth)
798
1512
 
@@ -800,18 +1514,27 @@ Output only the number. No explanation. No extra text.""",
800
1514
 
801
1515
  # Vote with litellm models
802
1516
  for m in self.models:
803
- res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
804
- if self.verbose:
805
- print(f"Model {m} voted: {res}", flush=True)
806
- votes.append({"model": m, "correct": res})
1517
+ try:
1518
+ res = self._ask_model(m, prompt, max_tokens, model_output, ground_truth)
1519
+ if self.verbose:
1520
+ print(f"Model {m} voted: {res}", flush=True)
1521
+ votes.append({"model": m, "result": res})
1522
+ except Exception as e:
1523
+ if self.verbose:
1524
+ print(f"Model {m} failed: {e}", flush=True)
1525
+ if self.fallback_comparison:
1526
+ res = str(model_output).strip() == str(ground_truth).strip()
1527
+ votes.append({"model": m, "result": res, "error": str(e)})
1528
+ else:
1529
+ raise
807
1530
 
808
1531
  # Vote with custom generation functions
809
1532
  for idx, custom_fn in enumerate(self.custom_generation_fns):
1533
+ fn_name = f"custom_fn_{idx}"
810
1534
  try:
811
1535
  content = custom_fn(prompt)
812
1536
  parsed = self.output_parser(content)
813
1537
 
814
- # If parser returns None and fallback is enabled, do string comparison
815
1538
  if parsed is None and self.fallback_comparison:
816
1539
  res = str(model_output).strip() == str(ground_truth).strip()
817
1540
  else:
@@ -819,27 +1542,70 @@ Output only the number. No explanation. No extra text.""",
819
1542
 
820
1543
  if self.verbose:
821
1544
  print(f"Custom function {idx} voted: {res}", flush=True)
822
- votes.append({"model": f"custom_fn_{idx}", "correct": res})
1545
+ votes.append({"model": fn_name, "result": res})
823
1546
  except Exception as e:
824
1547
  if self.verbose:
825
1548
  print(f"Custom function {idx} failed: {e}", flush=True)
826
- # If custom function fails and fallback is enabled, do string comparison
827
1549
  if self.fallback_comparison:
828
1550
  res = str(model_output).strip() == str(ground_truth).strip()
829
- votes.append({"model": f"custom_fn_{idx}", "correct": res})
1551
+ votes.append({"model": fn_name, "result": res, "error": str(e)})
830
1552
  else:
831
1553
  raise
832
1554
 
833
- if self.mode == "single":
834
- final = votes[0]["correct"]
835
- elif self.mode == "majority":
836
- true_votes = sum(v["correct"] for v in votes)
837
- false_votes = len(votes) - true_votes
838
- final = True if true_votes >= false_votes else False
839
- elif self.mode == "all":
840
- final = all(v["correct"] for v in votes)
1555
+ # Determine return type
1556
+ return_type = self._explicit_return_type
1557
+ if return_type is None:
1558
+ return_type = _validate_consistent_types(votes)
841
1559
  else:
842
- raise ValueError("mode must be 'majority', 'single', or 'all'")
843
-
844
- return {"correct": final, "mode": self.mode, "votes": votes}
845
-
1560
+ actual = _validate_consistent_types(votes)
1561
+ if actual != return_type:
1562
+ raise ValueError(f"Expected return type '{return_type.value}' but got '{actual.value}'")
1563
+
1564
+ # Auto-correct mode if needed
1565
+ if self._mode not in VALID_MODES[return_type]:
1566
+ self._mode = DEFAULT_MODES[return_type]
1567
+ self.mode = self._mode.value
1568
+
1569
+ # Validate map keys
1570
+ if return_type == ReturnType.MAP:
1571
+ _validate_consistent_map_keys(votes)
1572
+
1573
+ # Aggregate
1574
+ if return_type == ReturnType.BOOLEAN:
1575
+ final = self._aggregate_boolean(votes)
1576
+ elif return_type == ReturnType.SCALAR:
1577
+ final = self._aggregate_scalar(votes)
1578
+ elif return_type == ReturnType.MAP:
1579
+ final = self._aggregate_map(votes)
1580
+ else:
1581
+ raise ValueError(f"Unknown return type: {return_type}")
1582
+
1583
+ # Build backward-compatible response
1584
+ # Boolean: correct=bool, scores=None
1585
+ # Scalar: correct=score, scores=score (both fields for convenience)
1586
+ # Map: correct=None, scores=map
1587
+ if return_type == ReturnType.BOOLEAN:
1588
+ # Also put "correct" in each vote for backward compat
1589
+ for v in votes:
1590
+ v["correct"] = v["result"]
1591
+ return {
1592
+ "correct": final,
1593
+ "scores": None,
1594
+ "mode": self.mode,
1595
+ "votes": votes,
1596
+ }
1597
+ elif return_type == ReturnType.SCALAR:
1598
+ # For scalar, put score in both correct and scores for convenience
1599
+ return {
1600
+ "correct": final,
1601
+ "scores": final,
1602
+ "mode": self.mode,
1603
+ "votes": votes,
1604
+ }
1605
+ else: # MAP
1606
+ return {
1607
+ "correct": None,
1608
+ "scores": final,
1609
+ "mode": self.mode,
1610
+ "votes": votes,
1611
+ }
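
For orientation (not part of the diff), here is a minimal end-to-end sketch of the reworked judge() API as it appears in 0.1.14, using canned custom_generation_fns instead of live models so it runs offline; the prompts, lambdas, and scores are purely illustrative.

from llmasajudge import LLMAsAJudge

# Map-valued judging: each "judge" returns a JSON scorecard, which the
# 'multi_score' parser turns into a Dict[str, float]; the default aggregation
# mode for map results is "average".
judge = LLMAsAJudge(
    custom_generation_fns=[
        lambda prompt: '{"accuracy": 9, "clarity": 8}',
        lambda prompt: '{"accuracy": 7, "clarity": 8}',
    ],
    output_parser="multi_score",
)
result = judge.judge(input="What is 2 + 2?", model_output="4", ground_truth="4")
assert result["correct"] is None
assert result["scores"] == {"accuracy": 8.0, "clarity": 8.0}
assert result["mode"] == "average"

# Boolean judging keeps the old shape: 'correct' carries the majority vote and
# each vote still exposes a 'correct' field for backward compatibility.
bool_judge = LLMAsAJudge(
    custom_generation_fns=[lambda p: "right", lambda p: "wrong", lambda p: "right"],
)
result = bool_judge.judge(input="What is 2 + 2?", model_output="4", ground_truth="4")
assert result["correct"] is True and result["scores"] is None

With real judge models the same calls apply, except that models=["provider/model", ...] is passed instead of (or alongside) custom_generation_fns and each vote is produced through litellm.completion using the per-provider and per-model settings from config.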