judgeval 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +83 -0
- judgeval/clients.py +19 -0
- judgeval/common/__init__.py +8 -0
- judgeval/common/exceptions.py +28 -0
- judgeval/common/logger.py +189 -0
- judgeval/common/tracer.py +587 -0
- judgeval/common/utils.py +763 -0
- judgeval/constants.py +55 -0
- judgeval/data/__init__.py +14 -0
- judgeval/data/api_example.py +111 -0
- judgeval/data/datasets/__init__.py +4 -0
- judgeval/data/datasets/dataset.py +407 -0
- judgeval/data/datasets/ground_truth.py +54 -0
- judgeval/data/datasets/utils.py +74 -0
- judgeval/data/example.py +76 -0
- judgeval/data/result.py +83 -0
- judgeval/data/scorer_data.py +86 -0
- judgeval/evaluation_run.py +130 -0
- judgeval/judges/__init__.py +7 -0
- judgeval/judges/base_judge.py +44 -0
- judgeval/judges/litellm_judge.py +49 -0
- judgeval/judges/mixture_of_judges.py +248 -0
- judgeval/judges/together_judge.py +55 -0
- judgeval/judges/utils.py +45 -0
- judgeval/judgment_client.py +244 -0
- judgeval/run_evaluation.py +355 -0
- judgeval/scorers/__init__.py +30 -0
- judgeval/scorers/base_scorer.py +51 -0
- judgeval/scorers/custom_scorer.py +134 -0
- judgeval/scorers/judgeval_scorers/__init__.py +21 -0
- judgeval/scorers/judgeval_scorers/answer_relevancy.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_precision.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_recall.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_relevancy.py +22 -0
- judgeval/scorers/judgeval_scorers/faithfulness.py +19 -0
- judgeval/scorers/judgeval_scorers/hallucination.py +19 -0
- judgeval/scorers/judgeval_scorers/json_correctness.py +32 -0
- judgeval/scorers/judgeval_scorers/summarization.py +20 -0
- judgeval/scorers/judgeval_scorers/tool_correctness.py +19 -0
- judgeval/scorers/prompt_scorer.py +439 -0
- judgeval/scorers/score.py +427 -0
- judgeval/scorers/utils.py +175 -0
- judgeval-0.0.1.dist-info/METADATA +40 -0
- judgeval-0.0.1.dist-info/RECORD +46 -0
- judgeval-0.0.1.dist-info/WHEEL +4 -0
- judgeval-0.0.1.dist-info/licenses/LICENSE.md +202 -0
judgeval/constants.py
ADDED
@@ -0,0 +1,55 @@
"""
Constant variables used throughout source code
"""

from enum import Enum
import litellm
import os

class APIScorer(str, Enum):
    """
    Collection of proprietary scorers implemented by Judgment.

    These are ready-made evaluation scorers that can be used to evaluate
    Examples via the Judgment API.
    """
    FAITHFULNESS = "faithfulness"
    ANSWER_RELEVANCY = "answer_relevancy"
    HALLUCINATION = "hallucination"
    SUMMARIZATION = "summarization"
    CONTEXTUAL_RECALL = "contextual_recall"
    CONTEXTUAL_RELEVANCY = "contextual_relevancy"
    CONTEXTUAL_PRECISION = "contextual_precision"
    TOOL_CORRECTNESS = "tool_correctness"
    JSON_CORRECTNESS = "json_correctness"

    @classmethod
    def _missing_(cls, value):
        # Handle case-insensitive lookup
        for member in cls:
            if member.value == value.lower():
                return member

ROOT_API = os.getenv("JUDGMENT_API_URL", "https://api.judgmentlabs.ai")
## API URLs
JUDGMENT_EVAL_API_URL = f"{ROOT_API}/evaluate/"
JUDGMENT_DATASETS_PUSH_API_URL = f"{ROOT_API}/datasets/push/"
JUDGMENT_DATASETS_PULL_API_URL = f"{ROOT_API}/datasets/pull/"
JUDGMENT_EVAL_LOG_API_URL = f"{ROOT_API}/log_eval_results/"
JUDGMENT_EVAL_FETCH_API_URL = f"{ROOT_API}/fetch_eval_results/"
JUDGMENT_TRACES_SAVE_API_URL = f"{ROOT_API}/traces/save/"

## Models
TOGETHER_SUPPORTED_MODELS = {
    "QWEN": "Qwen/Qwen2-72B-Instruct",
    "LLAMA3_70B_INSTRUCT_TURBO": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    "LLAMA3_405B_INSTRUCT_TURBO": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "LLAMA3_8B_INSTRUCT_TURBO": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
    "MISTRAL_8x22B_INSTRUCT": "mistralai/Mixtral-8x22B-Instruct-v0.1",
    "MISTRAL_8x7B_INSTRUCT": "mistralai/Mixtral-8x7B-Instruct-v0.1",
}

ACCEPTABLE_MODELS = set(litellm.model_list) | set(TOGETHER_SUPPORTED_MODELS.keys())

## System settings
MAX_WORKER_THREADS = 10
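Two behaviors in constants.py are worth noting: APIScorer._missing_ makes enum lookup case-insensitive, and every endpoint URL is derived from ROOT_API when the module is imported, so JUDGMENT_API_URL must be set before judgeval is first imported. A minimal sketch using only names defined above; the staging URL is a placeholder, not a real endpoint.

import os

# Placeholder override; must run before judgeval is imported for the first time,
# since the JUDGMENT_* URLs are resolved at import time.
os.environ["JUDGMENT_API_URL"] = "https://staging.example.com"

from judgeval import constants

# _missing_ makes the lookup case-insensitive:
scorer = constants.APIScorer("FAITHFULNESS")
assert scorer is constants.APIScorer.FAITHFULNESS

# Endpoint URLs are built from ROOT_API:
print(constants.JUDGMENT_EVAL_API_URL)  # https://staging.example.com/evaluate/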
judgeval/data/__init__.py
ADDED
@@ -0,0 +1,14 @@
from judgeval.data.example import Example
from judgeval.data.api_example import ProcessExample, create_process_example
from judgeval.data.scorer_data import ScorerData, create_scorer_data
from judgeval.data.result import ScoringResult, generate_scoring_result

__all__ = [
    "Example",
    "ProcessExample",
    "create_process_example",
    "ScorerData",
    "create_scorer_data",
    "ScoringResult",
    "generate_scoring_result",
]
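Because this __init__.py re-exports the data models, downstream code can import them from judgeval.data directly. A small sketch: Example is defined in judgeval/data/example.py (not shown in this hunk), so the keyword arguments below are inferred from how Example is consumed elsewhere in this diff and are illustrative only.

# The re-exports above make these names importable from the subpackage root:
from judgeval.data import Example, ScorerData, ScoringResult

# Illustrative fields, not an exhaustive list of Example's schema.
example = Example(
    input="What is the capital of France?",
    actual_output="Paris",
    expected_output="Paris",
)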
judgeval/data/api_example.py
ADDED
@@ -0,0 +1,111 @@
from typing import List, Optional, Dict, Any, Union
from pydantic import BaseModel, Field, ConfigDict, model_validator

from judgeval.data.example import Example
from judgeval.data.scorer_data import ScorerData
from judgeval.common.logger import debug, error

class ProcessExample(BaseModel):
    """
    ProcessExample is an `Example` object that contains intermediate information
    about an undergoing evaluation on the original `Example`. It is used purely for
    internal operations and keeping track of the evaluation process.
    """
    name: str
    input: Optional[str] = None
    actual_output: Optional[str] = None
    expected_output: Optional[str] = None
    context: Optional[list] = None
    retrieval_context: Optional[list] = None
    tools_called: Optional[list] = None
    expected_tools: Optional[list] = None

    # make these optional, not all test cases in a conversation will be evaluated
    success: Optional[bool] = None
    scorers_data: Optional[List[ScorerData]] = None
    run_duration: Optional[float] = None
    evaluation_cost: Optional[float] = None

    order: Optional[int] = None
    # These should map 1 to 1 from golden
    additional_metadata: Optional[Dict] = None
    comments: Optional[str] = None
    trace_id: Optional[str] = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def update_scorer_data(self, scorer_data: ScorerData):
        """
        Updates scorer data field of test case after the scorers have been
        evaluated on this test case.
        """
        debug(f"Updating scorer data for example '{self.name}' with scorer: {scorer_data}")
        # self.scorers_data is a list of ScorerData objects that contain the
        # evaluation results of each scorer on this test case
        if self.scorers_data is None:
            self.scorers_data = [scorer_data]
        else:
            self.scorers_data.append(scorer_data)

        if self.success is None:
            # self.success will be None when it is a message
            # in that case we will be setting success for the first time
            self.success = scorer_data.success
        else:
            if scorer_data.success is False:
                debug(f"Example '{self.name}' marked as failed due to scorer: {scorer_data}")
                self.success = False

    def update_run_duration(self, run_duration: float):
        self.run_duration = run_duration

    @model_validator(mode="before")
    def check_input(cls, values: Dict[str, Any]):
        input = values.get("input")
        actual_output = values.get("actual_output")

        if (input is None or actual_output is None):
            error(f"Validation error: Required fields missing. input={input}, actual_output={actual_output}")
            raise ValueError(
                "'input' and 'actual_output' must be provided."
            )

        return values


def create_process_example(
    example: Example,
) -> ProcessExample:
    """
    When an LLM Test Case is executed, we track its progress using an ProcessExample.

    This will track things like the success of the test case, as well as the metadata (such as verdicts and claims in Faithfulness).
    """
    success = True
    if example.name is not None:
        name = example.name
    else:
        name = "Test Case Placeholder"
        debug(f"No name provided for example, using default name: {name}")
    order = None
    scorers_data = []

    debug(f"Creating ProcessExample for: {name}")
    process_ex = ProcessExample(
        name=name,
        input=example.input,
        actual_output=example.actual_output,
        expected_output=example.expected_output,
        context=example.context,
        retrieval_context=example.retrieval_context,
        tools_called=example.tools_called,
        expected_tools=example.expected_tools,
        success=success,
        scorers_data=scorers_data,
        run_duration=None,
        evaluation_cost=None,
        order=order,
        additional_metadata=example.additional_metadata,
        trace_id=example.trace_id
    )
    return process_ex
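The lifecycle this module implements: an Example is wrapped in a ProcessExample, scorer results are attached one at a time, and success flips to False as soon as any scorer fails. A sketch of that flow; the Example and ScorerData constructor arguments are illustrative assumptions, since those models are defined in other files of this wheel.

from judgeval.data import Example, ScorerData, create_process_example

# Illustrative only: Example and ScorerData live in other modules, so the
# keyword arguments here are guesses rather than documented API.
example = Example(input="Summarize the report.", actual_output="The report says ...")

process_ex = create_process_example(example)   # starts with success=True, scorers_data=[]

# ScorerData is assumed to expose at least a `success` flag (see update_scorer_data).
failing = ScorerData(name="faithfulness", success=False, score=0.2, threshold=0.5)
process_ex.update_scorer_data(failing)         # appends the result and flips success
process_ex.update_run_duration(1.42)

print(process_ex.success)       # False once any scorer reports success=False
print(process_ex.scorers_data)  # [ScorerData(...)]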
judgeval/data/datasets/dataset.py
ADDED
@@ -0,0 +1,407 @@
import ast
import csv
import datetime
import json
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
import requests
from dataclasses import dataclass, field
import os
from typing import List, Optional, Union, Literal

from judgeval.constants import JUDGMENT_DATASETS_PUSH_API_URL, JUDGMENT_DATASETS_PULL_API_URL
from judgeval.data.datasets.ground_truth import GroundTruthExample
from judgeval.data.datasets.utils import ground_truths_to_examples, examples_to_ground_truths
from judgeval.data import Example
from judgeval.common.logger import debug, error, warning, info

@dataclass
class EvalDataset:
    ground_truths: List[GroundTruthExample]
    examples: List[Example]
    _alias: Union[str, None] = field(default=None)
    _id: Union[str, None] = field(default=None)
    judgment_api_key: str = field(default="")

    def __init__(self,
                 judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"),
                 ground_truths: List[GroundTruthExample] = [],
                 examples: List[Example] = [],
                 ):
        debug(f"Initializing EvalDataset with {len(ground_truths)} ground truths and {len(examples)} examples")
        if not judgment_api_key:
            warning("No judgment_api_key provided")
        self.ground_truths = ground_truths
        self.examples = examples
        self._alias = None
        self._id = None
        self.judgment_api_key = judgment_api_key

    def push(self, alias: str, overwrite: Optional[bool] = False) -> bool:
        debug(f"Pushing dataset with alias '{alias}' (overwrite={overwrite})")
        if overwrite:
            warning(f"Overwrite enabled for alias '{alias}'")
        """
        Pushes the dataset to Judgment platform

        Mock request:
        {
            "alias": alias,
            "ground_truths": [...],
            "examples": [...],
            "overwrite": overwrite
        } ==>
        {
            "_alias": alias,
            "_id": "..."  # ID of the dataset
        }
        """
        with Progress(
            SpinnerColumn(style="rgb(106,0,255)"),
            TextColumn("[progress.description]{task.description}"),
            transient=False,
        ) as progress:
            task_id = progress.add_task(
                f"Pushing [rgb(106,0,255)]'{alias}' to Judgment...",
                total=100,
            )
            content = {
                "alias": alias,
                "ground_truths": [g.to_dict() for g in self.ground_truths],
                "examples": [e.to_dict() for e in self.examples],
                "overwrite": overwrite,
                "judgment_api_key": self.judgment_api_key
            }
            try:
                response = requests.post(
                    JUDGMENT_DATASETS_PUSH_API_URL,
                    json=content
                )
                if response.status_code == 500:
                    error(f"Server error during push: {content.get('message')}")
                    return False
                response.raise_for_status()
            except requests.exceptions.HTTPError as err:
                if response.status_code == 422:
                    error(f"Validation error during push: {err.response.json()}")
                else:
                    error(f"HTTP error during push: {err}")

            info(f"Successfully pushed dataset with alias '{alias}'")
            payload = response.json()
            self._alias = payload.get("_alias")
            self._id = payload.get("_id")
            progress.update(
                task_id,
                description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
            )
            return True

    def pull(self, alias: str):
        debug(f"Pulling dataset with alias '{alias}'")
        """
        Pulls the dataset from Judgment platform

        Mock request:
        {
            "alias": alias,
            "user_id": user_id
        }
        ==>
        {
            "ground_truths": [...],
            "examples": [...],
            "_alias": alias,
            "_id": "..."  # ID of the dataset
        }
        """
        # Make a POST request to the Judgment API to get the dataset

        with Progress(
            SpinnerColumn(style="rgb(106,0,255)"),
            TextColumn("[progress.description]{task.description}"),
            transient=False,
        ) as progress:
            task_id = progress.add_task(
                f"Pulling [rgb(106,0,255)]'{alias}'[/rgb(106,0,255)] from Judgment...",
                total=100,
            )
            request_body = {
                "alias": alias,
                "judgment_api_key": self.judgment_api_key
            }

            try:
                response = requests.post(
                    JUDGMENT_DATASETS_PULL_API_URL,
                    json=request_body
                )
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                error(f"Error pulling dataset: {str(e)}")
                raise

            info(f"Successfully pulled dataset with alias '{alias}'")
            payload = response.json()
            self.ground_truths = [GroundTruthExample(**g) for g in payload.get("ground_truths", [])]
            self.examples = [Example(**e) for e in payload.get("examples", [])]
            self._alias = payload.get("_alias")
            self._id = payload.get("_id")
            progress.update(
                task_id,
                description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
            )

    def add_from_json(self, file_path: str) -> None:
        debug(f"Loading dataset from JSON file: {file_path}")
        """
        Adds examples and ground truths from a JSON file.

        The format of the JSON file is expected to be a dictionary with two keys: "examples" and "ground_truths".
        The value of each key is a list of dictionaries, where each dictionary represents an example or ground truth.

        The JSON file is expected to have the following format:
        {
            "ground_truths": [
                {
                    "input": "test input",
                    "actual_output": null,
                    "expected_output": "expected output",
                    "context": ["context1"],
                    "retrieval_context": ["retrieval1"],
                    "additional_metadata": {"key": "value"},
                    "comments": "test comment",
                    "tools_called": ["tool1"],
                    "expected_tools": ["tool1"],
                    "source_file": "test.py",
                    "trace_id": "094121"
                }
            ],
            "examples": [
                {
                    "input": "test input",
                    "actual_output": "test output",
                    "expected_output": "expected output",
                    "context": ["context1", "context2"],
                    "retrieval_context": ["retrieval1"],
                    "additional_metadata": {"key": "value"},
                    "tools_called": ["tool1"],
                    "expected_tools": ["tool1", "tool2"],
                    "name": "test example",
                    "example_id": null,
                    "timestamp": "20241230_160117",
                    "trace_id": "123"
                }
            ]
        }
        """
        try:
            with open(file_path, "r") as file:
                payload = json.load(file)
                examples = payload.get("examples", [])
                ground_truths = payload.get("ground_truths", [])
        except FileNotFoundError:
            error(f"JSON file not found: {file_path}")
            raise FileNotFoundError(f"The file {file_path} was not found.")
        except json.JSONDecodeError:
            error(f"Invalid JSON file: {file_path}")
            raise ValueError(f"The file {file_path} is not a valid JSON file.")

        info(f"Added {len(examples)} examples and {len(ground_truths)} ground truths from JSON")
        new_examples = [Example(**e) for e in examples]
        for e in new_examples:
            self.add_example(e)

        new_ground_truths = [GroundTruthExample(**g) for g in ground_truths]
        for g in new_ground_truths:
            self.add_ground_truth(g)

    def add_from_csv(
        self,
        file_path: str,
    ) -> None:
        """
        Add Examples and GroundTruthExamples from a CSV file.
        """
        try:
            import pandas as pd
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install pandas to use this method. 'pip install pandas'"
            )

        # Pandas naturally reads numbers in data files as ints, not strings (can lead to unexpected behavior)
        df = pd.read_csv(file_path, dtype={'trace_id': str})
        """
        Expect the CSV to have headers

        "input", "actual_output", "expected_output", "context", \
        "retrieval_context", "additional_metadata", "tools_called", \
        "expected_tools", "name", "comments", "source_file", "example", \
        "trace_id"

        We want to collect the examples and ground truths separately which can
        be determined by the "example" column. If the value is True, then it is an
        example, otherwise it is a ground truth.

        We also assume that if there are multiple retrieval contexts or contexts, they are separated by semicolons.
        This can be adjusted using the `context_delimiter` and `retrieval_context_delimiter` parameters.
        """
        examples, ground_truths = [], []

        for _, row in df.iterrows():
            data = {
                "input": row["input"],
                "actual_output": row["actual_output"] if pd.notna(row["actual_output"]) else None,
                "expected_output": row["expected_output"] if pd.notna(row["expected_output"]) else None,
                "context": row["context"].split(";") if pd.notna(row["context"]) else [],
                "retrieval_context": row["retrieval_context"].split(";") if pd.notna(row["retrieval_context"]) else [],
                "additional_metadata": ast.literal_eval(row["additional_metadata"]) if pd.notna(row["additional_metadata"]) else dict(),
                "tools_called": row["tools_called"].split(";") if pd.notna(row["tools_called"]) else [],
                "expected_tools": row["expected_tools"].split(";") if pd.notna(row["expected_tools"]) else [],
                "trace_id": row["trace_id"] if pd.notna(row["trace_id"]) else None
            }
            if row["example"]:
                data["name"] = row["name"] if pd.notna(row["name"]) else None
                # every Example has `input` and `actual_output` fields
                if data["input"] is not None and data["actual_output"] is not None:
                    e = Example(**data)
                    examples.append(e)
                else:
                    raise ValueError("Every example must have an 'input' and 'actual_output' field.")
            else:
                # GroundTruthExample has `comments` and `source_file` fields
                data["comments"] = row["comments"] if pd.notna(row["comments"]) else None
                data["source_file"] = row["source_file"] if pd.notna(row["source_file"]) else None
                # every GroundTruthExample has `input` field
                if data["input"] is not None:
                    g = GroundTruthExample(**data)
                    ground_truths.append(g)
                else:
                    raise ValueError("Every ground truth must have an 'input' field.")

        for e in examples:
            self.add_example(e)

        for g in ground_truths:
            self.add_ground_truth(g)

    def add_example(self, e: Example) -> None:
        self.examples = self.examples + [e]
        # TODO if we need to add rank, then we need to do it here

    def add_ground_truth(self, g: GroundTruthExample) -> None:
        self.ground_truths = self.ground_truths + [g]

    def save_as(self, file_type: Literal["json", "csv"], dir_path: str, save_name: str = None) -> None:
        """
        Saves the dataset as a file. Save both the ground truths and examples.

        Args:
            file_type (Literal["json", "csv"]): The file type to save the dataset as.
            dir_path (str): The directory path to save the file to.
            save_name (str, optional): The name of the file to save. Defaults to None.
        """
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        file_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") if save_name is None else save_name
        complete_path = os.path.join(dir_path, f"{file_name}.{file_type}")
        if file_type == "json":
            with open(complete_path, "w") as file:
                json.dump(
                    {
                        "ground_truths": [g.to_dict() for g in self.ground_truths],
                        "examples": [e.to_dict() for e in self.examples],
                    },
                    file,
                    indent=4,
                )
        elif file_type == "csv":
            with open(complete_path, "w", newline="") as file:
                writer = csv.writer(file)
                writer.writerow([
                    "input", "actual_output", "expected_output", "context", \
                    "retrieval_context", "additional_metadata", "tools_called", \
                    "expected_tools", "name", "comments", "source_file", "example", \
                    "trace_id"
                ])
                for e in self.examples:
                    writer.writerow(
                        [
                            e.input,
                            e.actual_output,
                            e.expected_output,
                            ";".join(e.context),
                            ";".join(e.retrieval_context),
                            e.additional_metadata,
                            ";".join(e.tools_called),
                            ";".join(e.expected_tools),
                            e.name,
                            None,  # Example does not have comments
                            None,  # Example does not have source file
                            True,  # Adding an Example
                            e.trace_id
                        ]
                    )

                for g in self.ground_truths:
                    writer.writerow(
                        [
                            g.input,
                            g.actual_output,
                            g.expected_output,
                            ";".join(g.context),
                            ";".join(g.retrieval_context),
                            g.additional_metadata,
                            ";".join(g.tools_called),
                            ";".join(g.expected_tools),
                            None,  # GroundTruthExample does not have name
                            g.comments,
                            g.source_file,
                            False,  # Adding a GroundTruthExample, not an Example
                            g.trace_id
                        ]
                    )
        else:
            ACCEPTABLE_FILE_TYPES = ["json", "csv"]
            raise TypeError(f"Invalid file type: {file_type}. Please choose from {ACCEPTABLE_FILE_TYPES}")

    def __iter__(self):
        return iter(self.examples)

    def __len__(self):
        return len(self.examples)

    def __str__(self):
        return (
            f"{self.__class__.__name__}("
            f"ground_truths={self.ground_truths}, "
            f"examples={self.examples}, "
            f"_alias={self._alias}, "
            f"_id={self._id}"
            f")"
        )
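Taken together, EvalDataset supports a build-locally, persist, and sync workflow: collect Examples and GroundTruthExamples, write them to JSON or CSV, and push or pull against the Judgment platform. A minimal sketch under stated assumptions: the API key and alias are placeholders, and the Example fields are inferred from how this module consumes them.

from judgeval.data import Example
from judgeval.data.datasets.dataset import EvalDataset
from judgeval.data.datasets.ground_truth import GroundTruthExample

# Placeholders: the key and alias are not real values.
dataset = EvalDataset(judgment_api_key="YOUR_JUDGMENT_API_KEY")
dataset.add_example(Example(input="What is 2 + 2?", actual_output="4"))
dataset.add_ground_truth(GroundTruthExample(input="What is 2 + 2?", expected_output="4"))

# Persist locally; save_as falls back to a timestamped filename when save_name is omitted.
dataset.save_as("json", dir_path="./datasets", save_name="smoke_test")
# Note that reloading appends to the in-memory lists rather than replacing them.
dataset.add_from_json("./datasets/smoke_test.json")

# Sync with the Judgment platform (requires a valid key and network access).
if dataset.push(alias="smoke_test", overwrite=True):
    dataset.pull(alias="smoke_test")

print(len(dataset), dataset._alias, dataset._id)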
judgeval/data/datasets/ground_truth.py
ADDED
@@ -0,0 +1,54 @@
from pydantic import BaseModel
from typing import Optional, Dict, List


class GroundTruthExample(BaseModel):
    """
    GroundTruthExample is the atomic unit of a `Dataset`. It is essentially the same
    as an `Example`, but the `actual_output` field is optional to enable users to
    run their workflow on the `input` field at test-time to evaluate their current
    workflow's performance.
    """
    input: str
    actual_output: Optional[str] = None
    expected_output: Optional[str] = None
    context: Optional[List[str]] = None
    retrieval_context: Optional[List[str]] = None
    additional_metadata: Optional[Dict] = None
    comments: Optional[str] = None
    tools_called: Optional[List[str]] = None
    expected_tools: Optional[List[str]] = None
    source_file: Optional[str] = None
    trace_id: Optional[str] = None

    def to_dict(self):
        return {
            "input": self.input,
            "actual_output": self.actual_output,
            "expected_output": self.expected_output,
            "context": self.context,
            "retrieval_context": self.retrieval_context,
            "additional_metadata": self.additional_metadata,
            "comments": self.comments,
            "tools_called": self.tools_called,
            "expected_tools": self.expected_tools,
            "source_file": self.source_file,
            "trace_id": self.trace_id,
        }

    def __str__(self):
        return (
            f"{self.__class__.__name__}("
            f"input={self.input}, "
            f"actual_output={self.actual_output}, "
            f"expected_output={self.expected_output}, "
            f"context={self.context}, "
            f"retrieval_context={self.retrieval_context}, "
            f"additional_metadata={self.additional_metadata}, "
            f"comments={self.comments}, "
            f"tools_called={self.tools_called}, "
            f"expected_tools={self.expected_tools}, "
            f"source_file={self.source_file}, "
            f"trace_id={self.trace_id}"
            f")"
        )
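A short sketch of constructing a GroundTruthExample and serializing it the way EvalDataset does. Only `input` is required by the model above; the sample values are illustrative.

from judgeval.data.datasets.ground_truth import GroundTruthExample

# Only `input` is required; all other fields default to None.
gt = GroundTruthExample(
    input="Summarize the quarterly report.",
    expected_output="Revenue grew 12% quarter over quarter.",
    context=["Q3 revenue: $1.2M", "Q2 revenue: $1.07M"],
    comments="sample row for illustration",
)

# to_dict() is the shape that EvalDataset.push() and save_as("json") serialize.
print(gt.to_dict()["expected_output"])
print(gt)  # __str__ renders every field for quick inspection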