judgeval 0.0.39__py3-none-any.whl → 0.0.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/clients.py +6 -4
- judgeval/common/tracer.py +504 -257
- judgeval/common/utils.py +5 -1
- judgeval/constants.py +2 -0
- judgeval/data/__init__.py +2 -1
- judgeval/data/datasets/dataset.py +12 -6
- judgeval/data/datasets/eval_dataset_client.py +3 -1
- judgeval/data/example.py +7 -7
- judgeval/data/tool.py +29 -1
- judgeval/data/trace.py +31 -39
- judgeval/data/trace_run.py +2 -1
- judgeval/evaluation_run.py +4 -7
- judgeval/judgment_client.py +34 -7
- judgeval/run_evaluation.py +67 -19
- judgeval/scorers/__init__.py +4 -1
- judgeval/scorers/judgeval_scorer.py +12 -1
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +4 -0
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +124 -0
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +20 -0
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +1 -1
- judgeval/scorers/prompt_scorer.py +8 -164
- judgeval/scorers/score.py +15 -15
- judgeval-0.0.41.dist-info/METADATA +1450 -0
- {judgeval-0.0.39.dist-info → judgeval-0.0.41.dist-info}/RECORD +26 -24
- judgeval-0.0.39.dist-info/METADATA +0 -247
- {judgeval-0.0.39.dist-info → judgeval-0.0.41.dist-info}/WHEEL +0 -0
- {judgeval-0.0.39.dist-info → judgeval-0.0.41.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/utils.py
CHANGED
@@ -12,9 +12,10 @@ NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is a
 import asyncio
 import concurrent.futures
 import os
+from types import TracebackType
 import requests
 import pprint
-from typing import Any, Dict, List, Literal, Mapping, Optional, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, TypeAlias, Union

 # Third-party imports
 import litellm
@@ -782,3 +783,6 @@ if __name__ == "__main__":
     ]
 ]
 ))
+
+ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
+OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]
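A minimal sketch (not from the package) of how the new ExcInfo / OptExcInfo aliases at the end of utils.py can be used; the helper function name is illustrative.

import sys
from types import TracebackType
from typing import TypeAlias

ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]

def current_exc_info() -> OptExcInfo:
    # sys.exc_info() returns (None, None, None) outside an except block
    return sys.exc_info()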
judgeval/constants.py
CHANGED
@@ -28,6 +28,8 @@ class APIScorer(str, Enum):
     GROUNDEDNESS = "groundedness"
     DERAILMENT = "derailment"
     TOOL_ORDER = "tool_order"
+    CLASSIFIER = "classifier"
+    TOOL_DEPENDENCY = "tool_dependency"
     @classmethod
     def _missing_(cls, value):
         # Handle case-insensitive lookup
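A quick sketch of looking up the two new members, assuming the APIScorer enum shown above (a str-valued Enum whose _missing_ hook handles case-insensitive lookup).

from judgeval.constants import APIScorer

print(APIScorer.TOOL_DEPENDENCY.value)   # "tool_dependency"
print(APIScorer("classifier"))           # APIScorer.CLASSIFIER
print(APIScorer("Tool_Dependency"))      # assumed to resolve via the case-insensitive _missing_ hook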
judgeval/data/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from judgeval.data.example import Example, ExampleParams
 from judgeval.data.custom_example import CustomExample
 from judgeval.data.scorer_data import ScorerData, create_scorer_data
 from judgeval.data.result import ScoringResult, generate_scoring_result
-from judgeval.data.trace import Trace, TraceSpan
+from judgeval.data.trace import Trace, TraceSpan, TraceUsage


 __all__ = [
@@ -15,4 +15,5 @@ __all__ = [
     "generate_scoring_result",
     "Trace",
     "TraceSpan",
+    "TraceUsage"
 ]
judgeval/data/datasets/dataset.py
CHANGED
@@ -5,14 +5,15 @@ import json
 import os
 import yaml
 from dataclasses import dataclass, field
-from typing import List, Union, Literal
+from typing import List, Union, Literal, Optional

-from judgeval.data import Example
+from judgeval.data import Example, Trace
 from judgeval.common.logger import debug, error, warning, info

 @dataclass
 class EvalDataset:
     examples: List[Example]
+    traces: List[Trace]
     _alias: Union[str, None] = field(default=None)
     _id: Union[str, None] = field(default=None)
     judgment_api_key: str = field(default="")
@@ -20,12 +21,13 @@ class EvalDataset:
     def __init__(self,
                  judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"),
                  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
-                 examples: List[Example] =
+                 examples: Optional[List[Example]] = None,
+                 traces: Optional[List[Trace]] = None
                  ):
-        debug(f"Initializing EvalDataset with {len(examples)} examples")
         if not judgment_api_key:
             warning("No judgment_api_key provided")
-        self.examples = examples
+        self.examples = examples or []
+        self.traces = traces or []
         self._alias = None
         self._id = None
         self.judgment_api_key = judgment_api_key
@@ -218,8 +220,11 @@ class EvalDataset:
             self.add_example(e)

     def add_example(self, e: Example) -> None:
-        self.examples
+        self.examples.append(e)
         # TODO if we need to add rank, then we need to do it here
+
+    def add_trace(self, t: Trace) -> None:
+        self.traces.append(t)

     def save_as(self, file_type: Literal["json", "csv", "yaml"], dir_path: str, save_name: str = None) -> None:
         """
@@ -307,6 +312,7 @@ class EvalDataset:
         return (
             f"{self.__class__.__name__}("
             f"examples={self.examples}, "
+            f"traces={self.traces}, "
             f"_alias={self._alias}, "
             f"_id={self._id}"
             f")"
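A rough usage sketch of the new traces support on EvalDataset, assuming the constructor and add_trace shown above; the field values are illustrative.

from judgeval.data.datasets import EvalDataset
from judgeval.data import Example, Trace

dataset = EvalDataset(examples=[Example(input="What is 2 + 2?")])
dataset.add_trace(Trace(trace_id="t-1", name="demo-run",
                        created_at="2025-01-01T00:00:00", duration=0.5,
                        trace_spans=[]))  # Trace fields as declared in judgeval/data/trace.py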
judgeval/data/datasets/eval_dataset_client.py
CHANGED
@@ -13,7 +13,7 @@ from judgeval.constants import (
     JUDGMENT_DATASETS_INSERT_API_URL,
     JUDGMENT_DATASETS_EXPORT_JSONL_API_URL
 )
-from judgeval.data import Example
+from judgeval.data import Example, Trace
 from judgeval.data.datasets import EvalDataset


@@ -58,6 +58,7 @@ class EvalDatasetClient:
             "dataset_alias": alias,
             "project_name": project_name,
             "examples": [e.to_dict() for e in dataset.examples],
+            "traces": [t.model_dump() for t in dataset.traces],
             "overwrite": overwrite,
         }
         try:
@@ -202,6 +203,7 @@ class EvalDatasetClient:
         info(f"Successfully pulled dataset with alias '{alias}'")
         payload = response.json()
         dataset.examples = [Example(**e) for e in payload.get("examples", [])]
+        dataset.traces = [Trace(**t) for t in payload.get("traces", [])]
         dataset._alias = payload.get("alias")
         dataset._id = payload.get("id")
         progress.update(
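For reference, the dataset push request body now carries serialized traces alongside examples; this is an illustrative shape with the keys taken from the hunk above and placeholder values.

payload = {
    "dataset_alias": "my-dataset",
    "project_name": "my-project",
    "examples": [],   # [e.to_dict() for e in dataset.examples]
    "traces": [],     # [t.model_dump() for t in dataset.traces]
    "overwrite": False,
}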
judgeval/data/example.py
CHANGED
@@ -36,15 +36,15 @@ class Example(BaseModel):
     name: Optional[str] = None
     example_id: str = Field(default_factory=lambda: str(uuid4()))
     example_index: Optional[int] = None
-
+    created_at: Optional[str] = None
     trace_id: Optional[str] = None

     def __init__(self, **data):
         if 'example_id' not in data:
             data['example_id'] = str(uuid4())
         # Set timestamp if not provided
-        if '
-            data['
+        if 'created_at' not in data:
+            data['created_at'] = datetime.now().isoformat()
         super().__init__(**data)

     @field_validator('input', mode='before')
@@ -123,9 +123,9 @@ class Example(BaseModel):
             raise ValueError(f"Example index must be an integer or None but got {v} of type {type(v)}")
         return v

-    @field_validator('
+    @field_validator('created_at', mode='before')
     @classmethod
-    def
+    def validate_created_at(cls, v):
         if v is not None and not isinstance(v, str):
             raise ValueError(f"Timestamp must be a string or None but got {v} of type {type(v)}")
         return v
@@ -150,7 +150,7 @@ class Example(BaseModel):
             "name": self.name,
             "example_id": self.example_id,
             "example_index": self.example_index,
-            "
+            "created_at": self.created_at,
         }

     def __str__(self):
@@ -166,5 +166,5 @@ class Example(BaseModel):
             f"name={self.name}, "
             f"example_id={self.example_id}, "
             f"example_index={self.example_index}, "
-            f"
+            f"created_at={self.created_at}, "
         )
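The timestamp field on Example appears to have been renamed to created_at and is auto-populated by the constructor shown above; a small sketch, with illustrative field values.

from judgeval.data import Example

e = Example(input="What is 2 + 2?")
print(e.created_at)   # ISO-8601 string set in __init__ when not provided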
judgeval/data/tool.py
CHANGED
@@ -1,10 +1,14 @@
 from pydantic import BaseModel, field_validator
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, List
 import warnings

 class Tool(BaseModel):
     tool_name: str
     parameters: Optional[Dict[str, Any]] = None
+    agent_name: Optional[str] = None
+    result_dependencies: Optional[List[Dict[str, Any]]] = None
+    action_dependencies: Optional[List[Dict[str, Any]]] = None
+    require_all: Optional[bool] = None

     @field_validator('tool_name')
     def validate_tool_name(cls, v):
@@ -16,4 +20,28 @@ class Tool(BaseModel):
     def validate_parameters(cls, v):
         if v is not None and not isinstance(v, dict):
             warnings.warn(f"Parameters should be a dictionary, got {type(v)}", UserWarning)
+        return v
+
+    @field_validator('agent_name')
+    def validate_agent_name(cls, v):
+        if v is not None and not isinstance(v, str):
+            warnings.warn(f"Agent name should be a string, got {type(v)}", UserWarning)
+        return v
+
+    @field_validator('result_dependencies')
+    def validate_result_dependencies(cls, v):
+        if v is not None and not isinstance(v, list):
+            warnings.warn(f"Result dependencies should be a list, got {type(v)}", UserWarning)
+        return v
+
+    @field_validator('action_dependencies')
+    def validate_action_dependencies(cls, v):
+        if v is not None and not isinstance(v, list):
+            warnings.warn(f"Action dependencies should be a list, got {type(v)}", UserWarning)
+        return v
+
+    @field_validator('require_all')
+    def validate_require_all(cls, v):
+        if v is not None and not isinstance(v, bool):
+            warnings.warn(f"Require all should be a boolean, got {type(v)}", UserWarning)
         return v
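A construction sketch for the expanded Tool model; the contents of the dependency dicts are an assumption, since the validators above only require lists of dicts.

from judgeval.data.tool import Tool

tool = Tool(
    tool_name="search_flights",
    parameters={"origin": "SFO", "destination": "JFK"},
    agent_name="travel_agent",                          # new in 0.0.41
    result_dependencies=[{"tool_name": "get_dates"}],   # illustrative entry
    action_dependencies=[],
    require_all=True,
)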
judgeval/data/trace.py
CHANGED
@@ -5,36 +5,56 @@ from judgeval.data.tool import Tool
 import json
 from datetime import datetime, timezone

+class TraceUsage(BaseModel):
+    prompt_tokens: Optional[int] = None
+    completion_tokens: Optional[int] = None
+    total_tokens: Optional[int] = None
+    prompt_tokens_cost_usd: Optional[float] = None
+    completion_tokens_cost_usd: Optional[float] = None
+    total_cost_usd: Optional[float] = None
+    model_name: Optional[str] = None
+
 class TraceSpan(BaseModel):
     span_id: str
     trace_id: str
-    function:
+    function: str
     depth: int
     created_at: Optional[Any] = None
     parent_span_id: Optional[str] = None
     span_type: Optional[str] = "span"
     inputs: Optional[Dict[str, Any]] = None
+    error: Optional[Dict[str, Any]] = None
     output: Optional[Any] = None
+    usage: Optional[TraceUsage] = None
     duration: Optional[float] = None
     annotation: Optional[List[Dict[str, Any]]] = None
     evaluation_runs: Optional[List[EvaluationRun]] = []
     expected_tools: Optional[List[Tool]] = None
     additional_metadata: Optional[Dict[str, Any]] = None
+    has_evaluation: Optional[bool] = False
+    agent_name: Optional[str] = None
+    state_before: Optional[Dict[str, Any]] = None
+    state_after: Optional[Dict[str, Any]] = None

     def model_dump(self, **kwargs):
         return {
             "span_id": self.span_id,
             "trace_id": self.trace_id,
             "depth": self.depth,
-            # "created_at": datetime.fromtimestamp(self.created_at).isoformat(),
             "created_at": datetime.fromtimestamp(self.created_at, tz=timezone.utc).isoformat(),
-            "inputs": self.
-            "output": self.
+            "inputs": self._serialize_value(self.inputs),
+            "output": self._serialize_value(self.output),
+            "error": self._serialize_value(self.error),
             "evaluation_runs": [run.model_dump() for run in self.evaluation_runs] if self.evaluation_runs else [],
             "parent_span_id": self.parent_span_id,
             "function": self.function,
             "duration": self.duration,
-            "span_type": self.span_type
+            "span_type": self.span_type,
+            "usage": self.usage.model_dump() if self.usage else None,
+            "has_evaluation": self.has_evaluation,
+            "agent_name": self.agent_name,
+            "state_before": self.state_before,
+            "state_after": self.state_after
         }

     def print_span(self):
@@ -42,30 +62,6 @@ class TraceSpan(BaseModel):
         indent = " " * self.depth
         parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
         print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info}")
-
-    def _serialize_inputs(self) -> dict:
-        """Helper method to serialize input data safely."""
-        if self.inputs is None:
-            return {}
-
-        serialized_inputs = {}
-        for key, value in self.inputs.items():
-            if isinstance(value, BaseModel):
-                serialized_inputs[key] = value.model_dump()
-            elif isinstance(value, (list, tuple)):
-                # Handle lists/tuples of arguments
-                serialized_inputs[key] = [
-                    item.model_dump() if isinstance(item, BaseModel)
-                    else None if not self._is_json_serializable(item)
-                    else item
-                    for item in value
-                ]
-            else:
-                if self._is_json_serializable(value):
-                    serialized_inputs[key] = value
-                else:
-                    serialized_inputs[key] = self.safe_stringify(value, self.function)
-        return serialized_inputs

     def _is_json_serializable(self, obj: Any) -> bool:
         """Helper method to check if an object is JSON serializable."""
@@ -88,15 +84,11 @@ class TraceSpan(BaseModel):
             return repr(output)
         except (TypeError, OverflowError, ValueError):
             pass
-
-        warnings.warn(
-            f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
-        )
         return None

-    def
-    """Helper method to serialize
-    if
+    def _serialize_value(self, value: Any) -> Any:
+        """Helper method to deep serialize a value safely supporting Pydantic Models / regular PyObjects."""
+        if value is None:
            return None

        def serialize_value(value):
@@ -117,15 +109,15 @@ class TraceSpan(BaseModel):
                 # Fallback to safe stringification
                 return self.safe_stringify(value, self.function)

-        # Start serialization with the top-level
-        return serialize_value(
+        # Start serialization with the top-level value
+        return serialize_value(value)

 class Trace(BaseModel):
     trace_id: str
     name: str
     created_at: str
     duration: float
-
+    trace_spans: List[TraceSpan]
     overwrite: bool = False
     offline_mode: bool = False
     rules: Optional[Dict[str, Any]] = None
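A sketch of building a span with the new usage and error fields and dumping it, assuming the TraceSpan / TraceUsage models shown above; timestamps, token counts, and costs are illustrative.

import time
from judgeval.data import TraceSpan, TraceUsage

span = TraceSpan(
    span_id="span-1",
    trace_id="trace-1",
    function="call_llm",
    depth=0,
    created_at=time.time(),  # model_dump() converts this with datetime.fromtimestamp(..., tz=timezone.utc)
    usage=TraceUsage(prompt_tokens=120, completion_tokens=35, total_tokens=155,
                     total_cost_usd=0.0004, model_name="gpt-4.1"),
)
print(span.model_dump()["usage"])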
judgeval/data/trace_run.py
CHANGED
@@ -1,4 +1,3 @@
-
 from pydantic import BaseModel
 from typing import List, Optional, Dict, Any, Union, Callable
 from judgeval.data import Trace
@@ -22,6 +21,7 @@ class TraceRun(BaseModel):
         judgment_api_key (Optional[str]): The API key for running evaluations on the Judgment API
         rules (Optional[List[Rule]]): Rules to evaluate against scoring results
         append (Optional[bool]): Whether to append to existing evaluation results
+        tools (Optional[List[Dict[str, Any]]]): List of tools to use for evaluation
     """

     # The user will specify whether they want log_results when they call run_eval
@@ -40,6 +40,7 @@ class TraceRun(BaseModel):
     judgment_api_key: Optional[str] = ""
     override: Optional[bool] = False
     rules: Optional[List[Rule]] = None
+    tools: Optional[List[Dict[str, Any]]] = None

     class Config:
         arbitrary_types_allowed = True
judgeval/evaluation_run.py
CHANGED
@@ -1,5 +1,5 @@
 from typing import List, Optional, Dict, Any, Union
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, Field

 from judgeval.data import Example, CustomExample
 from judgeval.scorers import JudgevalScorer, APIJudgmentScorer
@@ -27,12 +27,12 @@ class EvaluationRun(BaseModel):
     # The user will specify whether they want log_results when they call run_eval
     log_results: bool = False  # NOTE: log_results has to be set first because it is used to validate project_name and eval_name
     organization_id: Optional[str] = None
-    project_name: Optional[str] = None
-    eval_name: Optional[str] = None
+    project_name: Optional[str] = Field(default=None, validate_default=True)
+    eval_name: Optional[str] = Field(default=None, validate_default=True)
     examples: Union[List[Example], List[CustomExample]]
     scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
     model: Optional[Union[str, List[str], JudgevalJudge]] = "gpt-4.1"
-    aggregator: Optional[str] = None
+    aggregator: Optional[str] = Field(default=None, validate_default=True)
     metadata: Optional[Dict[str, Any]] = None
     trace_span_id: Optional[str] = None
     # API Key will be "" until user calls client.run_eval(), then API Key will be set
@@ -96,9 +96,6 @@ class EvaluationRun(BaseModel):
     def validate_scorers(cls, v):
         if not v:
             raise ValueError("Scorers cannot be empty.")
-        for s in v:
-            if not isinstance(s, APIJudgmentScorer) and not isinstance(s, JudgevalScorer):
-                raise ValueError(f"Invalid type for Scorer: {type(s)}")
         return v

     @field_validator('model')
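The switch to Field(default=None, validate_default=True) makes pydantic run the field validators even when the default is used. A minimal standalone sketch of that behaviour (not the package's actual validator logic):

from typing import Optional
from pydantic import BaseModel, Field, field_validator

class Run(BaseModel):
    log_results: bool = False
    project_name: Optional[str] = Field(default=None, validate_default=True)

    @field_validator("project_name")
    @classmethod
    def check_project_name(cls, v, info):
        # Runs even when project_name is left at its default of None
        if info.data.get("log_results") and v is None:
            raise ValueError("project_name is required when log_results=True")
        return v

Run()                   # ok: log_results defaults to False
Run(log_results=True)   # raises: the validator saw the default None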
judgeval/judgment_client.py
CHANGED
@@ -5,6 +5,7 @@ import os
 from uuid import uuid4
 from typing import Optional, List, Dict, Any, Union, Callable
 import requests
+import asyncio

 from judgeval.constants import ROOT_API
 from judgeval.data.datasets import EvalDataset, EvalDatasetClient
@@ -62,7 +63,15 @@ class SingletonMeta(type):
         return cls._instances[cls]

 class JudgmentClient(metaclass=SingletonMeta):
-    def __init__(self, judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"), organization_id: str = os.getenv("JUDGMENT_ORG_ID")):
+    def __init__(self, judgment_api_key: Optional[str] = os.getenv("JUDGMENT_API_KEY"), organization_id: Optional[str] = os.getenv("JUDGMENT_ORG_ID")):
+        # Check if API key is None
+        if judgment_api_key is None:
+            raise ValueError("JUDGMENT_API_KEY cannot be None. Please provide a valid API key or set the JUDGMENT_API_KEY environment variable.")
+
+        # Check if organization ID is None
+        if organization_id is None:
+            raise ValueError("JUDGMENT_ORG_ID cannot be None. Please provide a valid organization ID or set the JUDGMENT_ORG_ID environment variable.")
+
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
         self.eval_dataset_client = EvalDatasetClient(judgment_api_key, organization_id)
@@ -121,7 +130,8 @@ class JudgmentClient(metaclass=SingletonMeta):
         ignore_errors: bool = True,
         rules: Optional[List[Rule]] = None,
         function: Optional[Callable] = None,
-        tracer: Optional[Union[Tracer, BaseCallbackHandler]] = None
+        tracer: Optional[Union[Tracer, BaseCallbackHandler]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None
     ) -> List[ScoringResult]:
         try:

@@ -151,6 +161,7 @@ class JudgmentClient(metaclass=SingletonMeta):
                 append=append,
                 judgment_api_key=self.judgment_api_key,
                 organization_id=self.organization_id,
+                tools=tools
             )
             return run_trace_eval(trace_run, override, ignore_errors, function, tracer, examples)
         except ValueError as e:
@@ -173,7 +184,7 @@ class JudgmentClient(metaclass=SingletonMeta):
         ignore_errors: bool = True,
         async_execution: bool = False,
         rules: Optional[List[Rule]] = None
-    ) -> List[ScoringResult]:
+    ) -> Union[List[ScoringResult], asyncio.Task]:
         """
         Executes an evaluation of `Example`s using one or more `Scorer`s

@@ -480,7 +491,7 @@ class JudgmentClient(metaclass=SingletonMeta):

         return response.json()["slug"]

-
+    def assert_test(
         self,
         scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
         examples: Optional[List[Example]] = None,
@@ -495,6 +506,7 @@ class JudgmentClient(metaclass=SingletonMeta):
         rules: Optional[List[Rule]] = None,
         function: Optional[Callable] = None,
         tracer: Optional[Union[Tracer, BaseCallbackHandler]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
         async_execution: bool = False
     ) -> None:
         """
@@ -513,6 +525,14 @@ class JudgmentClient(metaclass=SingletonMeta):
             override (bool): Whether to override an existing evaluation run with the same name
             rules (Optional[List[Rule]]): Rules to evaluate against scoring results
         """
+
+        # Check for enable_param_checking and tools
+        for scorer in scorers:
+            if hasattr(scorer, "kwargs") and scorer.kwargs is not None:
+                if scorer.kwargs.get("enable_param_checking") is True:
+                    if not tools:
+                        raise ValueError(f"You must provide the 'tools' argument to assert_test when using a scorer with enable_param_checking=True. If you do not want to do param checking, explicitly set enable_param_checking=False for the {scorer.__name__} scorer.")
+
         # Validate that exactly one of examples or test_file is provided
         if (examples is None and test_file is None) or (examples is not None and test_file is not None):
             raise ValueError("Exactly one of 'examples' or 'test_file' must be provided, but not both")
@@ -530,10 +550,11 @@ class JudgmentClient(metaclass=SingletonMeta):
                 rules=rules,
                 function=function,
                 tracer=tracer,
-                test_file=test_file
+                test_file=test_file,
+                tools=tools
             )
         else:
-            results =
+            results = self.run_evaluation(
                 examples=examples,
                 scorers=scorers,
                 model=model,
@@ -547,4 +568,10 @@ class JudgmentClient(metaclass=SingletonMeta):
                 async_execution=async_execution
             )

-
+        if async_execution:
+            # 'results' is an asyncio.Task here, awaiting it gives List[ScoringResult]
+            actual_results = asyncio.run(results)
+            assert_test(actual_results)  # Call the synchronous imported function
+        else:
+            # 'results' is already List[ScoringResult] here (synchronous path)
+            assert_test(results)  # Call the synchronous imported function
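A hedged usage sketch of the client changes above: the constructor now fails fast on missing credentials, and assert_test accepts a tools argument that is required whenever a scorer sets enable_param_checking=True. The scorer placeholder and Example fields are illustrative, not taken verbatim from the package.

from judgeval import JudgmentClient
from judgeval.data import Example

client = JudgmentClient()  # raises ValueError if JUDGMENT_API_KEY or JUDGMENT_ORG_ID is unset

client.assert_test(
    examples=[Example(input="book a flight")],
    scorers=[...],                             # e.g. a tool-dependency scorer from judgeval.scorers
    tools=[{"tool_name": "search_flights"}],   # required when a scorer has enable_param_checking=True
)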