judgeval 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -41,11 +41,13 @@ from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL, JUDGMENT_TRACES_FET
 from judgeval.judgment_client import JudgmentClient
 from judgeval.data import Example
 from judgeval.scorers import APIJudgmentScorer, JudgevalScorer, ScorerWrapper
+from judgeval.rules import Rule
+from judgeval.evaluation_run import EvaluationRun
+from judgeval.judges import JudgevalJudge
 
 from rich import print as rprint
 
 from judgeval.data.result import ScoringResult
-from judgeval.evaluation_run import EvaluationRun
 
 # Define type aliases for better code readability and maintainability
 ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
@@ -285,17 +287,29 @@ class TraceManagerClient:
 
 class TraceClient:
     """Client for managing a single trace context"""
-    def __init__(self, tracer, trace_id: str, name: str, project_name: str = "default_project", overwrite: bool = False):
-        self.tracer = tracer
-        self.trace_id = trace_id
+
+    def __init__(
+        self,
+        tracer: Optional["Tracer"],
+        trace_id: Optional[str] = None,
+        name: str = "default",
+        project_name: str = "default_project",
+        overwrite: bool = False,
+        rules: Optional[List[Rule]] = None,
+    ):
         self.name = name
+        self.trace_id = trace_id or str(uuid.uuid4())
         self.project_name = project_name
+        self.overwrite = overwrite
+        self.tracer = tracer
+        # Initialize rules with either provided rules or an empty list
+        self.rules = rules or []
+
         self.client: JudgmentClient = tracer.client
         self.entries: List[TraceEntry] = []
         self.start_time = time.time()
         self.span_type = None
         self._current_span: Optional[TraceEntry] = None
-        self.overwrite = overwrite
         self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id) # Manages DB operations for trace data
 
     @contextmanager
@@ -348,7 +362,7 @@ class TraceClient:
         expected_tools: Optional[List[str]] = None,
         additional_metadata: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-        log_results: Optional[bool] = True,
+        log_results: Optional[bool] = True
     ):
         start_time = time.time() # Record start time
         example = Example(
@@ -362,29 +376,68 @@ class TraceClient:
             additional_metadata=additional_metadata,
             trace_id=self.trace_id
         )
-
+        loaded_rules = None
+        if self.rules:
+            loaded_rules = []
+            for rule in self.rules:
+                processed_conditions = []
+                for condition in rule.conditions:
+                    # Convert metric if it's a ScorerWrapper
+                    try:
+                        if isinstance(condition.metric, ScorerWrapper):
+                            condition_copy = condition.model_copy()
+                            condition_copy.metric = condition.metric.load_implementation(use_judgment=True)
+                            processed_conditions.append(condition_copy)
+                        else:
+                            processed_conditions.append(condition)
+                    except Exception as e:
+                        warnings.warn(f"Failed to convert ScorerWrapper in rule '{rule.name}', condition metric '{condition.metric_name}': {str(e)}")
+                        processed_conditions.append(condition) # Keep original condition as fallback
+
+                # Create new rule with processed conditions
+                new_rule = rule.model_copy()
+                new_rule.conditions = processed_conditions
+                loaded_rules.append(new_rule)
         try:
             # Load appropriate implementations for all scorers
-            loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = [
-                scorer.load_implementation(use_judgment=True) if isinstance(scorer, ScorerWrapper) else scorer
-                for scorer in scorers
-            ]
+            loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = []
+            for scorer in scorers:
+                try:
+                    if isinstance(scorer, ScorerWrapper):
+                        loaded_scorers.append(scorer.load_implementation(use_judgment=True))
+                    else:
+                        loaded_scorers.append(scorer)
+                except Exception as e:
+                    warnings.warn(f"Failed to load implementation for scorer {scorer}: {str(e)}")
+                    # Skip this scorer
+
+            if not loaded_scorers:
+                warnings.warn("No valid scorers available for evaluation")
+                return
+
+            # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
+            if loaded_rules and any(isinstance(scorer, JudgevalScorer) for scorer in loaded_scorers):
+                raise ValueError("Cannot use Judgeval scorers (only API scorers) when using rules. Please either remove rules or use only APIJudgmentScorer types.")
+
         except Exception as e:
-            raise ValueError(f"Failed to load scorers: {str(e)}")
+            warnings.warn(f"Failed to load scorers: {str(e)}")
+            return
 
+        # Combine the trace-level rules with any evaluation-specific rules)
         eval_run = EvaluationRun(
             organization_id=self.tracer.organization_id,
             log_results=log_results,
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
                       f"{self._current_span}-"
-                      f"[{','.join(scorer.load_implementation().score_type.capitalize() for scorer in scorers)}]",
+                      f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
             examples=[example],
             scorers=loaded_scorers,
             model=model,
             metadata={},
             judgment_api_key=self.tracer.api_key,
-            override=self.overwrite
+            override=self.overwrite,
+            rules=loaded_rules # Use the combined rules
         )
 
         self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
@@ -562,7 +615,6 @@ class TraceClient:
             "empty_save": empty_save,
             "overwrite": overwrite
         }
-
         # Execute asynchrous evaluation in the background
         if not empty_save: # Only send to RabbitMQ if the trace is not empty
             connection = pika.BlockingConnection(
@@ -572,13 +624,16 @@ class TraceClient:
             channel.queue_declare(queue=RABBITMQ_QUEUE, durable=True)
             trace_data["judgment_api_key"] = self.tracer.api_key
             trace_data["organization_id"] = self.tracer.organization_id
-
             channel.basic_publish(
                 exchange='',
                 routing_key=RABBITMQ_QUEUE,
                 body=json.dumps(trace_data),
                 properties=pika.BasicProperties(
-                    delivery_mode=pika.DeliveryMode.Transient # Changed from Persistent to Transient
+                    delivery_mode=pika.DeliveryMode.Transient, # Changed from Persistent to Transient
+                    headers={
+                        'api_key': self.tracer.api_key,
+                        'organization_id': self.tracer.organization_id
+                    }
             ))
             connection.close()
 
@@ -597,7 +652,12 @@ class Tracer:
             cls._instance = super(Tracer, cls).__new__(cls)
         return cls._instance
 
-    def __init__(self, api_key: str = os.getenv("JUDGMENT_API_KEY"), project_name: str = "default_project", organization_id: str = os.getenv("ORGANIZATION_ID")):
+    def __init__(
+        self,
+        api_key: str = os.getenv("JUDGMENT_API_KEY"),
+        project_name: str = "default_project",
+        rules: Optional[List[Rule]] = None, # Added rules parameter
+        organization_id: str = os.getenv("JUDGMENT_ORG_ID")):
         if not hasattr(self, 'initialized'):
             if not api_key:
                 raise ValueError("Tracer must be configured with a Judgment API key")
@@ -611,6 +671,7 @@ class Tracer:
             self.organization_id: str = organization_id
             self.depth: int = 0
             self._current_trace: Optional[str] = None
+            self.rules: List[Rule] = rules or [] # Store rules at tracer level
             self.initialized: bool = True
         elif hasattr(self, 'project_name') and self.project_name != project_name:
             warnings.warn(
@@ -621,11 +682,25 @@ class Tracer:
             )
 
     @contextmanager
-    def trace(self, name: str, project_name: str = None, overwrite: bool = False) -> Generator[TraceClient, None, None]:
+    def trace(
+        self,
+        name: str,
+        project_name: str = None,
+        overwrite: bool = False,
+        rules: Optional[List[Rule]] = None # Added rules parameter
+    ) -> Generator[TraceClient, None, None]:
         """Start a new trace context using a context manager"""
         trace_id = str(uuid.uuid4())
         project = project_name if project_name is not None else self.project_name
-        trace = TraceClient(self, trace_id, name, project_name=project, overwrite=overwrite)
+
+        trace = TraceClient(
+            self,
+            trace_id,
+            name,
+            project_name=project,
+            overwrite=overwrite,
+            rules=self.rules # Pass combined rules to the trace client
+        )
         prev_trace = self._current_trace
         self._current_trace = trace
 
@@ -669,9 +744,9 @@ class Tracer:
                 trace = self._current_trace
             else:
                 trace_id = str(uuid.uuid4())
-                trace_name = str(uuid.uuid4())
+                trace_name = func.__name__
                 project = project_name if project_name is not None else self.project_name
-                trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite)
+                trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules)
                 self._current_trace = trace
                 # Only save empty trace for the root call
                 trace.save(empty_save=True, overwrite=overwrite)
@@ -706,9 +781,9 @@ class Tracer:
                 trace = self._current_trace
             else:
                 trace_id = str(uuid.uuid4())
-                trace_name = str(uuid.uuid4())
+                trace_name = func.__name__
                 project = project_name if project_name is not None else self.project_name
-                trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite)
+                trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules)
                 self._current_trace = trace
                 # Only save empty trace for the root call
                 trace.save(empty_save=True, overwrite=overwrite)
@@ -741,14 +816,15 @@ def wrap(client: Any) -> Any:
     Wraps an API client to add tracing capabilities.
     Supports OpenAI, Together, and Anthropic clients.
     """
-    tracer = Tracer._instance # Get the global tracer instance
-
     # Get the appropriate configuration for this client type
     span_name, original_create = _get_client_config(client)
 
     def traced_create(*args, **kwargs):
-        # Skip tracing if no active trace
-        if not (tracer and tracer._current_trace):
+        # Get the current tracer instance (might be created after client was wrapped)
+        tracer = Tracer._instance
+
+        # Skip tracing if no tracer exists or no active trace
+        if not tracer or not tracer._current_trace:
             return original_create(*args, **kwargs)
 
        with tracer._current_trace.span(span_name, span_type="llm") as span:
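Taken together, the tracer.py changes thread rule support end to end: Tracer stores an optional list of Rule objects, every TraceClient it creates receives them, and async_evaluate converts any ScorerWrapper metrics inside those rules before attaching them to the EvaluationRun. A minimal usage sketch under those assumptions (Rule construction is not shown in this diff, so it is left abstract, and the key/ID strings are placeholders):

```python
from typing import List

from judgeval.common.tracer import Tracer, wrap
from judgeval.rules import Rule

# Rules would be built elsewhere; the diff only shows that a Rule has a name and
# conditions whose metrics may be ScorerWrapper instances, so none are fabricated here.
rules: List[Rule] = []

tracer = Tracer(
    api_key="your-judgment-api-key",   # defaults to the JUDGMENT_API_KEY env var
    organization_id="your-org-id",     # default env var is now JUDGMENT_ORG_ID
    rules=rules,                       # stored on the tracer and passed to every TraceClient
)

# wrap() no longer captures Tracer._instance at wrap time; it looks the tracer up on
# each call, so clients can be wrapped before the Tracer singleton exists.
# client = wrap(OpenAI())

with tracer.trace("my-workflow") as trace:
    ...  # spans recorded here; trace.async_evaluate(...) carries the tracer-level rules
```

Note that in the hunk above, trace() accepts its own rules argument but still forwards self.rules to TraceClient, so per-call rules appear to be accepted without yet being merged.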
judgeval/common/utils.py CHANGED
@@ -21,7 +21,6 @@ from judgeval.clients import async_together_client, together_client
 from judgeval.constants import *
 from judgeval.common.logger import debug, error
 
-LITELLM_SUPPORTED_MODELS = set(litellm.model_list)
 
 class CustomModelParameters(pydantic.BaseModel):
     model_name: str
@@ -72,7 +71,7 @@ class ChatCompletionRequest(pydantic.BaseModel):
     def validate_model(cls, model):
         if not model:
             raise ValueError("Model cannot be empty")
-        if model not in TOGETHER_SUPPORTED_MODELS and model not in LITELLM_SUPPORTED_MODELS:
+        if model not in ACCEPTABLE_MODELS:
             raise ValueError(f"Model {model} is not in the list of supported models.")
         return model
 
@@ -114,13 +113,13 @@ def fetch_together_api_response(model: str, messages: List[Mapping], response_fo
     if request.response_format is not None:
         debug(f"Using response format: {request.response_format}")
         response = together_client.chat.completions.create(
-            model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+            model=request.model,
             messages=request.messages,
             response_format=request.response_format
         )
     else:
         response = together_client.chat.completions.create(
-            model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+            model=request.model,
             messages=request.messages,
         )
 
@@ -144,13 +143,13 @@ async def afetch_together_api_response(model: str, messages: List[Mapping], resp
     if request.response_format is not None:
         debug(f"Using response format: {request.response_format}")
         response = await async_together_client.chat.completions.create(
-            model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+            model=request.model,
             messages=request.messages,
             response_format=request.response_format
         )
     else:
         response = await async_together_client.chat.completions.create(
-            model=TOGETHER_SUPPORTED_MODELS.get(request.model),
+            model=request.model,
             messages=request.messages,
         )
     return response.choices[0].message.content
@@ -174,8 +173,8 @@ def query_together_api_multiple_calls(models: List[str], messages: List[List[Map
 
     # Validate all models are supported
     for model in models:
-        if model not in TOGETHER_SUPPORTED_MODELS:
-            raise ValueError(f"Model {model} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")
+        if model not in ACCEPTABLE_MODELS:
+            raise ValueError(f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}.")
 
     # Validate input lengths match
     if response_formats is None:
@@ -223,8 +222,8 @@ async def aquery_together_api_multiple_calls(models: List[str], messages: List[L
 
     # Validate all models are supported
     for model in models:
-        if model not in TOGETHER_SUPPORTED_MODELS:
-            raise ValueError(f"Model {model} is not in the list of supported TogetherAI models: {TOGETHER_SUPPORTED_MODELS}.")
+        if model not in ACCEPTABLE_MODELS:
+            raise ValueError(f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}.")
 
     # Validate input lengths match
     if response_formats is None:
@@ -322,8 +321,8 @@ async def afetch_litellm_api_response(model: str, messages: List[Mapping], respo
     # Add validation
     validate_chat_messages(messages)
 
-    if model not in LITELLM_SUPPORTED_MODELS:
-        raise ValueError(f"Model {model} is not in the list of supported Litellm models: {LITELLM_SUPPORTED_MODELS}.")
+    if model not in ACCEPTABLE_MODELS:
+        raise ValueError(f"Model {model} is not in the list of supported models: {ACCEPTABLE_MODELS}.")
 
     if response_format is not None:
         response = await litellm.acompletion(
@@ -409,7 +408,7 @@ async def aquery_litellm_api_multiple_calls(models: List[str], messages: List[Ma
         models (List[str]): List of models to query
         messages (List[Mapping]): List of messages to query
         response_formats (List[pydantic.BaseModel], optional): A list of the format of the response if JSON forcing. Defaults to None.
-
+
     Returns:
         List[str]: Litellm responses for each model and message pair in order. Any exceptions in the thread call result in a None.
     """
judgeval/constants.py CHANGED
@@ -51,20 +51,71 @@ JUDGMENT_TRACES_DELETE_API_URL = f"{ROOT_API}/traces/delete/"
 RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq-networklb-faa155df16ec9085.elb.us-west-1.amazonaws.com")
 RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", 5672)
 RABBITMQ_QUEUE = os.getenv("RABBITMQ_QUEUE", "task_queue")
-
 # Models
-TOGETHER_SUPPORTED_MODELS = {
-    "QWEN": "Qwen/Qwen2-72B-Instruct",
-    "LLAMA3_70B_INSTRUCT_TURBO": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-    "LLAMA3_405B_INSTRUCT_TURBO": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
-    "LLAMA3_8B_INSTRUCT_TURBO": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-    "MISTRAL_8x22B_INSTRUCT": "mistralai/Mixtral-8x22B-Instruct-v0.1",
-    "MISTRAL_8x7B_INSTRUCT": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-}
+LITELLM_SUPPORTED_MODELS = set(litellm.model_list)
+
+TOGETHER_SUPPORTED_MODELS = [
+    "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
+    "Qwen/Qwen2-VL-72B-Instruct",
+    "meta-llama/Llama-Vision-Free",
+    "Gryphe/MythoMax-L2-13b",
+    "Qwen/Qwen2.5-72B-Instruct-Turbo",
+    "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
+    "deepseek-ai/DeepSeek-R1",
+    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
+    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+    "google/gemma-2-27b-it",
+    "mistralai/Mistral-Small-24B-Instruct-2501",
+    "mistralai/Mixtral-8x22B-Instruct-v0.1",
+    "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
+    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-classifier",
+    "deepseek-ai/DeepSeek-V3",
+    "Qwen/Qwen2-72B-Instruct",
+    "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+    "upstage/SOLAR-10.7B-Instruct-v1.0",
+    "togethercomputer/MoA-1",
+    "Qwen/QwQ-32B-Preview",
+    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    "mistralai/Mistral-7B-Instruct-v0.2",
+    "databricks/dbrx-instruct",
+    "meta-llama/Llama-3-8b-chat-hf",
+    "google/gemma-2b-it",
+    "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
+    "google/gemma-2-9b-it",
+    "meta-llama/Llama-3.3-70B-Instruct-Turbo",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-p",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+    "Gryphe/MythoMax-L2-13b-Lite",
+    "meta-llama/Llama-2-7b-chat-hf",
+    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+    "meta-llama/Llama-2-13b-chat-hf",
+    "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
+    "scb10x/scb10x-llama3-typhoon-v1-5x-4f316",
+    "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
+    "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "microsoft/WizardLM-2-8x22B",
+    "mistralai/Mistral-7B-Instruct-v0.3",
+    "scb10x/scb10x-llama3-1-typhoon2-60256",
+    "Qwen/Qwen2.5-7B-Instruct-Turbo",
+    "scb10x/scb10x-llama3-1-typhoon-18370",
+    "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+    "meta-llama/Llama-3-70b-chat-hf",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "togethercomputer/MoA-1-Turbo",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
+    "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+    "mistralai/Mistral-7B-Instruct-v0.1"
+]
 
 JUDGMENT_SUPPORTED_MODELS = {"osiris-large", "osiris-mini"}
 
-ACCEPTABLE_MODELS = set(litellm.model_list) | set(TOGETHER_SUPPORTED_MODELS.keys()) | JUDGMENT_SUPPORTED_MODELS
+ACCEPTABLE_MODELS = set(litellm.model_list) | set(TOGETHER_SUPPORTED_MODELS) | JUDGMENT_SUPPORTED_MODELS
 
 ## System settings
 MAX_WORKER_THREADS = 10
+
+# Maximum number of concurrent operations for evaluation runs
+MAX_CONCURRENT_EVALUATIONS = 50 # Adjust based on system capabilities
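With TOGETHER_SUPPORTED_MODELS now a flat list of full model identifiers rather than an alias-to-identifier dict, ACCEPTABLE_MODELS contains the identifiers themselves, so the 0.0.14-style aliases ("QWEN", "MISTRAL_8x7B_INSTRUCT", ...) no longer pass validation. A quick check of the new behavior (a sketch, assuming the constants above import as shown):

```python
from judgeval.constants import ACCEPTABLE_MODELS, TOGETHER_SUPPORTED_MODELS

print("Qwen/Qwen2-72B-Instruct" in TOGETHER_SUPPORTED_MODELS)  # True: full identifier
print("Qwen/Qwen2-72B-Instruct" in ACCEPTABLE_MODELS)          # True: part of the union
print("QWEN" in ACCEPTABLE_MODELS)                             # False: old-style alias is gone
```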
@@ -20,7 +20,7 @@ class EvalDataset:
     organization_id: str = field(default="")
     def __init__(self,
                  judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"),
-                 organization_id: str = os.getenv("ORGANIZATION_ID"),
+                 organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
                  ground_truths: List[GroundTruthExample] = [],
                  examples: List[Example] = [],
                  ):
@@ -233,7 +233,6 @@ class EvalDatasetClient:
             "alias": alias,
             "examples": [e.to_dict() for e in examples],
             "ground_truths": [g.to_dict() for g in ground_truths],
-            "judgment_api_key": self.judgment_api_key
         }
 
         try:
@@ -6,6 +6,7 @@ from judgeval.scorers import JudgevalScorer, APIJudgmentScorer
 from judgeval.constants import ACCEPTABLE_MODELS
 from judgeval.common.logger import debug, error
 from judgeval.judges import JudgevalJudge
+from judgeval.rules import Rule
 
 class EvaluationRun(BaseModel):
     """
@@ -20,6 +21,7 @@ class EvaluationRun(BaseModel):
         aggregator (Optional[str]): The aggregator to use for evaluation if using Mixture of Judges
         metadata (Optional[Dict[str, Any]]): Additional metadata to include for this evaluation run, e.g. comments, dataset name, purpose, etc.
         judgment_api_key (Optional[str]): The API key for running evaluations on the Judgment API
+        rules (Optional[List[Rule]]): Rules to evaluate against scoring results
     """
 
     # The user will specify whether they want log_results when they call run_eval
@@ -35,6 +37,7 @@ class EvaluationRun(BaseModel):
     # API Key will be "" until user calls client.run_eval(), then API Key will be set
     judgment_api_key: Optional[str] = ""
     override: Optional[bool] = False
+    rules: Optional[List[Rule]] = None
 
     def model_dump(self, **kwargs):
         data = super().model_dump(**kwargs)
@@ -45,6 +48,11 @@ class EvaluationRun(BaseModel):
             else {"score_type": scorer.score_type, "threshold": scorer.threshold}
             for scorer in self.scorers
         ]
+
+        if self.rules:
+            # Process rules to ensure proper serialization
+            data["rules"] = [rule.model_dump() for rule in self.rules]
+
         return data
 
     @field_validator('log_results', mode='before')
@@ -14,7 +14,7 @@ BASE_CONVERSATION = [
 ]
 
 class TogetherJudge(JudgevalJudge):
-    def __init__(self, model: str = "QWEN", **kwargs):
+    def __init__(self, model: str = "Qwen/Qwen2.5-72B-Instruct-Turbo", **kwargs):
         debug(f"Initializing TogetherJudge with model={model}")
         self.model = model
         self.kwargs = kwargs
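Consistent with the alias removal, TogetherJudge now defaults to the full Together identifier. A one-line sketch (the import path is assumed, since this hunk does not show its file name):

```python
from judgeval.judges import TogetherJudge  # import path assumed

judge = TogetherJudge()  # defaults to "Qwen/Qwen2.5-72B-Instruct-Turbo" instead of the old "QWEN" alias
```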
judgeval/judges/utils.py CHANGED
@@ -39,7 +39,7 @@ def create_judge(
                     Please either set the `use_judgment` flag to True or use
                     non-Judgment models."""
                 )
-            if m not in LITELLM_SUPPORTED_MODELS and m not in TOGETHER_SUPPORTED_MODELS:
+            if m not in ACCEPTABLE_MODELS:
                 raise InvalidJudgeModelError(f"Invalid judge model chosen: {m}")
         return MixtureOfJudges(models=model), True
     # If model is a string, check that it corresponds to a valid model