judgeval 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +229 -44
- judgeval/constants.py +15 -3
- judgeval/data/datasets/__init__.py +2 -1
- judgeval/data/datasets/dataset.py +1 -122
- judgeval/data/datasets/eval_dataset_client.py +193 -0
- judgeval/data/result.py +16 -1
- judgeval/evaluation_run.py +2 -1
- judgeval/judges/utils.py +14 -2
- judgeval/judgment_client.py +64 -7
- judgeval/run_evaluation.py +19 -0
- judgeval/scorers/judgeval_scorer.py +8 -8
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +3 -1
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +6 -3
- judgeval/scorers/prompt_scorer.py +2 -2
- judgeval/scorers/score.py +11 -11
- judgeval/scorers/utils.py +3 -3
- judgeval/tracer/__init__.py +3 -0
- {judgeval-0.0.9.dist-info → judgeval-0.0.11.dist-info}/METADATA +5 -4
- {judgeval-0.0.9.dist-info → judgeval-0.0.11.dist-info}/RECORD +21 -19
- {judgeval-0.0.9.dist-info → judgeval-0.0.11.dist-info}/WHEEL +0 -0
- {judgeval-0.0.9.dist-info → judgeval-0.0.11.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -2,13 +2,26 @@
 Tracing system for judgeval that allows for function tracing using decorators.
 """

+import os
 import time
 import functools
 import requests
 import uuid
 from contextlib import contextmanager
-from typing import
-
+from typing import (
+    Optional,
+    Any,
+    List,
+    Literal,
+    Tuple,
+    Generator,
+    TypeAlias,
+    Union
+)
+from dataclasses import (
+    dataclass,
+    field
+)
 from datetime import datetime
 from openai import OpenAI
 from together import Together
@@ -21,16 +34,25 @@ import warnings
 from pydantic import BaseModel
 from http import HTTPStatus

-
+import pika
+import os
+
+from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL, JUDGMENT_TRACES_FETCH_API_URL, RABBITMQ_HOST, RABBITMQ_PORT, RABBITMQ_QUEUE, JUDGMENT_TRACES_DELETE_API_URL
 from judgeval.judgment_client import JudgmentClient
 from judgeval.data import Example
-from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
+from judgeval.scorers import APIJudgmentScorer, JudgevalScorer, ScorerWrapper
+
+from rich import print as rprint
+
 from judgeval.data.result import ScoringResult
+from judgeval.evaluation_run import EvaluationRun

 # Define type aliases for better code readability and maintainability
 ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
 TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
 SpanType = Literal['span', 'tool', 'llm', 'evaluation']
+
+
 @dataclass
 class TraceEntry:
     """Represents a single trace entry with its visual representation.
@@ -52,7 +74,7 @@ class TraceEntry:
     # Use field() for mutable defaults to avoid shared state issues
     inputs: dict = field(default_factory=dict)
     span_type: SpanType = "span"
-
+    evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)

     def print_entry(self):
         indent = " " * self.depth
@@ -65,7 +87,8 @@ class TraceEntry:
         elif self.type == "input":
             print(f"{indent}Input: {self.inputs}")
         elif self.type == "evaluation":
-
+            for evaluation_run in self.evaluation_runs:
+                print(f"{indent}Evaluation: {evaluation_run.model_dump()}")

     def _serialize_inputs(self) -> dict:
         """Helper method to serialize input data safely.
@@ -112,7 +135,7 @@ class TraceEntry:
             "duration": self.duration,
             "output": self._serialize_output(),
             "inputs": self._serialize_inputs(),
-            "
+            "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
             "span_type": self.span_type
         }

@@ -121,8 +144,29 @@ class TraceEntry:

         Handles special cases:
         - Pydantic models are converted using model_dump()
+        - We try to serialize into JSON, then string, then the base representation (__repr__)
         - Non-serializable objects return None with a warning
         """
+
+        def safe_stringify(output, function_name):
+            """
+            Safely converts an object to a string or repr, handling serialization issues gracefully.
+            """
+            try:
+                return str(output)
+            except (TypeError, OverflowError, ValueError):
+                pass
+
+            try:
+                return repr(output)
+            except (TypeError, OverflowError, ValueError):
+                pass
+
+            warnings.warn(
+                f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
+            )
+            return None
+
         if isinstance(self.output, BaseModel):
             return self.output.model_dump()

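The serialization fallback added here degrades gracefully: JSON first, then str(), then repr(), then None with a warning. A minimal standalone sketch of the same cascade outside the TraceEntry class (the helper name and demo value are illustrative only, not part of the package):

    import json
    import warnings

    def fallback_serialize(output, function_name):
        # Same order as _serialize_output: JSON -> str -> repr -> None
        try:
            json.dumps(output)
            return output
        except (TypeError, OverflowError, ValueError):
            pass
        for converter in (str, repr):
            try:
                return converter(output)
            except (TypeError, OverflowError, ValueError):
                continue
        warnings.warn(f"Output for function {function_name} is not serializable. Setting to None.")
        return None

    print(fallback_serialize({"a", "b"}, "demo"))  # sets are not JSON-serializable, so str() is used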
@@ -131,8 +175,107 @@ class TraceEntry:
             json.dumps(self.output)
             return self.output
         except (TypeError, OverflowError, ValueError):
-
-
+            return safe_stringify(self.output, self.function)
+
+
+class TraceManagerClient:
+    """
+    Client for handling trace endpoints with the Judgment API
+
+
+    Operations include:
+    - Fetching a trace by id
+    - Saving a trace
+    - Deleting a trace
+    """
+    def __init__(self, judgment_api_key: str):
+        self.judgment_api_key = judgment_api_key
+
+    def fetch_trace(self, trace_id: str):
+        """
+        Fetch a trace by its id
+        """
+        response = requests.post(
+            JUDGMENT_TRACES_FETCH_API_URL,
+            json={
+                "trace_id": trace_id,
+                "judgment_api_key": self.judgment_api_key,
+            },
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to fetch traces: {response.text}")
+
+        return response.json()
+
+    def save_trace(self, trace_data: dict, empty_save: bool):
+        """
+        Saves a trace to the database
+
+        Args:
+            trace_data: The trace data to save
+            empty_save: Whether to save an empty trace
+            NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
+        """
+        response = requests.post(
+            JUDGMENT_TRACES_SAVE_API_URL,
+            json=trace_data,
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code == HTTPStatus.BAD_REQUEST:
+            raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
+        elif response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to save trace data: {response.text}")
+
+        if not empty_save and "ui_results_url" in response.json():
+            rprint(f"\n🔍 You can view your trace data here: [rgb(106,0,255)]{response.json()['ui_results_url']}[/]\n")
+
+    def delete_trace(self, trace_id: str):
+        """
+        Delete a trace from the database.
+        """
+        response = requests.delete(
+            JUDGMENT_TRACES_DELETE_API_URL,
+            json={
+                "judgment_api_key": self.judgment_api_key,
+                "trace_ids": [trace_id],
+            },
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to delete trace: {response.text}")
+
+        return response.json()
+
+    def delete_traces(self, trace_ids: List[str]):
+        """
+        Delete a batch of traces from the database.
+        """
+        response = requests.delete(
+            JUDGMENT_TRACES_DELETE_API_URL,
+            json={
+                "judgment_api_key": self.judgment_api_key,
+                "trace_ids": trace_ids,
+            },
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to delete trace: {response.text}")
+
+        return response.json()
+

 class TraceClient:
     """Client for managing a single trace context"""
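The new TraceManagerClient gathers the fetch/save/delete trace endpoints behind a single object keyed by the API key. A minimal usage sketch, assuming a valid Judgment API key and an existing trace id (both placeholders here):

    from judgeval.common.tracer import TraceManagerClient

    client = TraceManagerClient(judgment_api_key="<your-api-key>")

    trace = client.fetch_trace(trace_id="<existing-trace-id>")       # raises ValueError on a non-200 response
    client.save_trace(trace_data=trace, empty_save=False)            # prints the UI link when the API returns one
    client.delete_traces(trace_ids=["<old-trace-1>", "<old-trace-2>"])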
@@ -147,6 +290,7 @@ class TraceClient:
         self.span_type = None
         self._current_span: Optional[TraceEntry] = None
         self.overwrite = overwrite
+        self.trace_manager_client = TraceManagerClient(tracer.api_key) # Manages DB operations for trace data

     @contextmanager
     def span(self, name: str, span_type: SpanType = "span"):
@@ -163,6 +307,7 @@ class TraceClient:
             span_type=span_type
         ))

+        # Increment nested depth and set current span
         self.tracer.depth += 1
         prev_span = self._current_span
         self._current_span = name
@@ -185,7 +330,7 @@ class TraceClient:
         ))
         self._current_span = prev_span

-
+    def async_evaluate(
         self,
         scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
         input: Optional[str] = None,
@@ -211,25 +356,40 @@ class TraceClient:
             additional_metadata=additional_metadata,
             trace_id=self.trace_id
         )
-
-
-            scorers
-
-
+
+        try:
+            # Load appropriate implementations for all scorers
+            loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = [
+                scorer.load_implementation(use_judgment=True) if isinstance(scorer, ScorerWrapper) else scorer
+                for scorer in scorers
+            ]
+        except Exception as e:
+            raise ValueError(f"Failed to load scorers: {str(e)}")
+
+        eval_run = EvaluationRun(
             log_results=log_results,
             project_name=self.project_name,
-
-            f"{self.name.capitalize()}-"
+            eval_name=f"{self.name.capitalize()}-"
                       f"{self._current_span}-"
-                      f"[{','.join(scorer.load_implementation().score_type.capitalize() for scorer in scorers)}]"
-
+                      f"[{','.join(scorer.load_implementation().score_type.capitalize() for scorer in scorers)}]",
+            examples=[example],
+            scorers=loaded_scorers,
+            model=model,
+            metadata={},
+            judgment_api_key=self.tracer.api_key,
             override=self.overwrite
         )

-        self.
+        self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation

-    def
-        """
+    def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
+        """
+        Add evaluation run data to the trace
+
+        Args:
+            eval_run (EvaluationRun): The evaluation run to add to the trace
+            start_time (float): The start time of the evaluation run
+        """
         if self._current_span:
             duration = time.time() - start_time # Calculate duration from start_time

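With scorer loading and EvaluationRun construction now handled inside async_evaluate, a call site only needs to supply scorers plus the example fields. A hedged sketch of how that might look from inside a traced span; only scorers and input appear in this hunk, so the remaining keyword arguments, the scorer import path, and the model name are assumptions:

    # trace_client is assumed to be an active TraceClient for the current trace
    from judgeval.scorers import FaithfulnessScorer  # assumed import path

    with trace_client.span("answer_question", span_type="llm"):
        answer = "Paris is the capital of France."   # stand-in for a real LLM call
        trace_client.async_evaluate(
            scorers=[FaithfulnessScorer()],
            input="What is the capital of France?",
            actual_output=answer,                    # assumed parameter name
            model="gpt-4o-mini",                     # placeholder model name
            log_results=True,
        )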
@@ -239,7 +399,7 @@ class TraceClient:
                 depth=self.tracer.depth,
                 message=f"Evaluation results for {self._current_span}",
                 timestamp=time.time(),
-
+                evaluation_runs=[eval_run],
                 duration=duration,
                 span_type="evaluation"
             ))
@@ -320,7 +480,7 @@ class TraceClient:
                     "timestamp": entry["timestamp"],
                     "inputs": None,
                     "output": None,
-                    "
+                    "evaluation_runs": [],
                     "span_type": entry.get("span_type", "span")
                 }
                 active_functions.append(function)
@@ -343,8 +503,8 @@ class TraceClient:
             if entry["type"] == "output" and entry["output"]:
                 current_entry["output"] = entry["output"]

-            if entry["type"] == "evaluation" and entry["
-                current_entry["
+            if entry["type"] == "evaluation" and entry["evaluation_runs"]:
+                current_entry["evaluation_runs"] = entry["evaluation_runs"]

         # Sort by timestamp
         condensed.sort(key=lambda x: x["timestamp"])
@@ -361,6 +521,24 @@ class TraceClient:
         raw_entries = [entry.to_dict() for entry in self.entries]
         condensed_entries = self.condense_trace(raw_entries)

+        # Calculate total token counts from LLM API calls
+        total_prompt_tokens = 0
+        total_completion_tokens = 0
+        total_tokens = 0
+
+        for entry in condensed_entries:
+            if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
+                usage = entry["output"].get("usage", {})
+                # Handle OpenAI/Together format
+                if "prompt_tokens" in usage:
+                    total_prompt_tokens += usage.get("prompt_tokens", 0)
+                    total_completion_tokens += usage.get("completion_tokens", 0)
+                # Handle Anthropic format
+                elif "input_tokens" in usage:
+                    total_prompt_tokens += usage.get("input_tokens", 0)
+                    total_completion_tokens += usage.get("output_tokens", 0)
+                total_tokens += usage.get("total_tokens", 0)
+
         # Create trace document
         trace_data = {
             "trace_id": self.trace_id,
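The totals above are computed per provider shape: OpenAI and Together report prompt_tokens/completion_tokens, Anthropic reports input_tokens/output_tokens, and total_tokens is only added when the usage block carries it. Two example usage blocks with made-up numbers:

    # OpenAI / Together style usage on an "llm" span output
    openai_usage = {"prompt_tokens": 120, "completion_tokens": 45, "total_tokens": 165}

    # Anthropic style usage; no total_tokens key, so the running total gets 0 from this entry
    anthropic_usage = {"input_tokens": 120, "output_tokens": 45}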
@@ -370,31 +548,38 @@ class TraceClient:
             "created_at": datetime.fromtimestamp(self.start_time).isoformat(),
             "duration": total_duration,
             "token_counts": {
-                "prompt_tokens":
-                "completion_tokens":
-                "total_tokens":
-            },
+                "prompt_tokens": total_prompt_tokens,
+                "completion_tokens": total_completion_tokens,
+                "total_tokens": total_tokens,
+            },
             "entries": condensed_entries,
             "empty_save": empty_save,
             "overwrite": overwrite
         }
-
-        # Save trace data by making POST request to API
-        response = requests.post(
-            JUDGMENT_TRACES_SAVE_API_URL,
-            json=trace_data,
-            headers={
-                "Content-Type": "application/json",
-            }
-        )

-        if
-
-
-
+        if not empty_save:
+            connection = pika.BlockingConnection(
+                pika.ConnectionParameters(host=RABBITMQ_HOST, port=RABBITMQ_PORT))
+            channel = connection.channel()
+
+            channel.queue_declare(queue=RABBITMQ_QUEUE, durable=True)
+
+            channel.basic_publish(
+                exchange='',
+                routing_key=RABBITMQ_QUEUE,
+                body=json.dumps(trace_data),
+                properties=pika.BasicProperties(
+                    delivery_mode=pika.DeliveryMode.Transient # Changed from Persistent to Transient
+                ))
+            connection.close()

+        self.trace_manager_client.save_trace(trace_data, empty_save)
+
         return self.trace_id, trace_data

+    def delete(self):
+        return self.trace_manager_client.delete_trace(self.trace_id)
+
 class Tracer:
     _instance = None

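Besides the POST to the save endpoint, non-empty traces are now also published to a RabbitMQ queue. A sketch of what a matching consumer could look like with pika, assuming the same RABBITMQ_* settings from judgeval.constants; the handler body is a placeholder:

    import json
    import pika

    from judgeval.constants import RABBITMQ_HOST, RABBITMQ_PORT, RABBITMQ_QUEUE

    def handle_trace(ch, method, properties, body):
        # Placeholder handler: inspect the trace payload and acknowledge it
        trace_data = json.loads(body)
        print(f"received trace {trace_data['trace_id']} with {len(trace_data['entries'])} entries")
        ch.basic_ack(delivery_tag=method.delivery_tag)

    connection = pika.BlockingConnection(
        pika.ConnectionParameters(host=RABBITMQ_HOST, port=RABBITMQ_PORT))
    channel = connection.channel()
    channel.queue_declare(queue=RABBITMQ_QUEUE, durable=True)  # must match the producer's declaration
    channel.basic_consume(queue=RABBITMQ_QUEUE, on_message_callback=handle_trace)
    channel.start_consuming()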
@@ -403,7 +588,7 @@ class Tracer:
             cls._instance = super(Tracer, cls).__new__(cls)
         return cls._instance

-    def __init__(self, api_key: str):
+    def __init__(self, api_key: str = os.getenv("JUDGMENT_API_KEY")):
         if not hasattr(self, 'initialized'):

             if not api_key:
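Because api_key now defaults to os.getenv("JUDGMENT_API_KEY"), the Tracer can be constructed with no arguments when that variable is set. A minimal sketch; the key value is a placeholder, and since the default is evaluated when tracer.py is imported, the variable must be set beforehand:

    import os
    os.environ.setdefault("JUDGMENT_API_KEY", "<your-api-key>")  # or export it in the shell

    from judgeval.common.tracer import Tracer

    tracer = Tracer()  # falls back to JUDGMENT_API_KEY; Tracer is a singleton via __new__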
judgeval/constants.py
CHANGED
@@ -32,15 +32,25 @@ class APIScorer(str, Enum):
         return member

 ROOT_API = os.getenv("JUDGMENT_API_URL", "https://api.judgmentlabs.ai")
-
+# API URLs
 JUDGMENT_EVAL_API_URL = f"{ROOT_API}/evaluate/"
 JUDGMENT_DATASETS_PUSH_API_URL = f"{ROOT_API}/datasets/push/"
 JUDGMENT_DATASETS_PULL_API_URL = f"{ROOT_API}/datasets/pull/"
+JUDGMENT_DATASETS_PULL_ALL_API_URL = f"{ROOT_API}/datasets/get_all_stats/"
 JUDGMENT_EVAL_LOG_API_URL = f"{ROOT_API}/log_eval_results/"
 JUDGMENT_EVAL_FETCH_API_URL = f"{ROOT_API}/fetch_eval_results/"
+JUDGMENT_EVAL_DELETE_API_URL = f"{ROOT_API}/delete_eval_results_by_project_and_run_name/"
+JUDGMENT_EVAL_DELETE_PROJECT_API_URL = f"{ROOT_API}/delete_eval_results_by_project/"
+JUDGMENT_TRACES_FETCH_API_URL = f"{ROOT_API}/traces/fetch/"
 JUDGMENT_TRACES_SAVE_API_URL = f"{ROOT_API}/traces/save/"
+JUDGMENT_TRACES_DELETE_API_URL = f"{ROOT_API}/traces/delete/"

-
+# RabbitMQ
+RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq-networklb-faa155df16ec9085.elb.us-west-1.amazonaws.com")
+RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", 5672)
+RABBITMQ_QUEUE = os.getenv("RABBITMQ_QUEUE", "task_queue")
+
+# Models
 TOGETHER_SUPPORTED_MODELS = {
     "QWEN": "Qwen/Qwen2-72B-Instruct",
     "LLAMA3_70B_INSTRUCT_TURBO": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
@@ -50,7 +60,9 @@ TOGETHER_SUPPORTED_MODELS = {
     "MISTRAL_8x7B_INSTRUCT": "mistralai/Mixtral-8x7B-Instruct-v0.1",
 }

-
+JUDGMENT_SUPPORTED_MODELS = {"osiris-large", "osiris-mini"}
+
+ACCEPTABLE_MODELS = set(litellm.model_list) | set(TOGETHER_SUPPORTED_MODELS.keys()) | JUDGMENT_SUPPORTED_MODELS

 ## System settings
 MAX_WORKER_THREADS = 10
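ACCEPTABLE_MODELS now unions litellm's model list, the Together aliases, and the new Judgment-hosted models. A small sketch of the kind of membership check this enables; the helper function is illustrative, not part of the package:

    from judgeval.constants import ACCEPTABLE_MODELS

    def validate_model(name: str) -> str:
        # Reject anything not known to litellm, Together, or Judgment ("osiris-*")
        if name not in ACCEPTABLE_MODELS:
            raise ValueError(f"Unsupported model: {name}")
        return name

    validate_model("osiris-mini")  # one of JUDGMENT_SUPPORTED_MODELS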
judgeval/data/datasets/__init__.py
CHANGED
@@ -1,4 +1,5 @@
 from judgeval.data.datasets.dataset import EvalDataset
 from judgeval.data.datasets.ground_truth import GroundTruthExample
+from judgeval.data.datasets.eval_dataset_client import EvalDatasetClient

-__all__ = ["EvalDataset", "GroundTruthExample"]
+__all__ = ["EvalDataset", "EvalDatasetClient", "GroundTruthExample"]
judgeval/data/datasets/dataset.py
CHANGED
@@ -2,16 +2,11 @@ import ast
 import csv
 import datetime
 import json
-from rich.console import Console
-from rich.progress import Progress, SpinnerColumn, TextColumn
-import requests
 from dataclasses import dataclass, field
 import os
 from typing import List, Optional, Union, Literal

-from judgeval.constants import JUDGMENT_DATASETS_PUSH_API_URL, JUDGMENT_DATASETS_PULL_API_URL
 from judgeval.data.datasets.ground_truth import GroundTruthExample
-from judgeval.data.datasets.utils import ground_truths_to_examples, examples_to_ground_truths
 from judgeval.data import Example
 from judgeval.common.logger import debug, error, warning, info

@@ -37,120 +32,6 @@ class EvalDataset:
         self._id = None
         self.judgment_api_key = judgment_api_key

-    def push(self, alias: str, overwrite: Optional[bool] = False) -> bool:
-        debug(f"Pushing dataset with alias '{alias}' (overwrite={overwrite})")
-        if overwrite:
-            warning(f"Overwrite enabled for alias '{alias}'")
-        """
-        Pushes the dataset to Judgment platform
-
-        Mock request:
-        {
-            "alias": alias,
-            "ground_truths": [...],
-            "examples": [...],
-            "overwrite": overwrite
-        } ==>
-        {
-            "_alias": alias,
-            "_id": "..." # ID of the dataset
-        }
-        """
-        with Progress(
-            SpinnerColumn(style="rgb(106,0,255)"),
-            TextColumn("[progress.description]{task.description}"),
-            transient=False,
-        ) as progress:
-            task_id = progress.add_task(
-                f"Pushing [rgb(106,0,255)]'{alias}' to Judgment...",
-                total=100,
-            )
-            content = {
-                "alias": alias,
-                "ground_truths": [g.to_dict() for g in self.ground_truths],
-                "examples": [e.to_dict() for e in self.examples],
-                "overwrite": overwrite,
-                "judgment_api_key": self.judgment_api_key
-            }
-            try:
-                response = requests.post(
-                    JUDGMENT_DATASETS_PUSH_API_URL,
-                    json=content
-                )
-                if response.status_code == 500:
-                    error(f"Server error during push: {content.get('message')}")
-                    return False
-                response.raise_for_status()
-            except requests.exceptions.HTTPError as err:
-                if response.status_code == 422:
-                    error(f"Validation error during push: {err.response.json()}")
-                else:
-                    error(f"HTTP error during push: {err}")
-
-            info(f"Successfully pushed dataset with alias '{alias}'")
-            payload = response.json()
-            self._alias = payload.get("_alias")
-            self._id = payload.get("_id")
-            progress.update(
-                task_id,
-                description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
-            )
-            return True
-
-    def pull(self, alias: str):
-        debug(f"Pulling dataset with alias '{alias}'")
-        """
-        Pulls the dataset from Judgment platform
-
-        Mock request:
-        {
-            "alias": alias,
-            "user_id": user_id
-        }
-        ==>
-        {
-            "ground_truths": [...],
-            "examples": [...],
-            "_alias": alias,
-            "_id": "..." # ID of the dataset
-        }
-        """
-        # Make a POST request to the Judgment API to get the dataset
-
-        with Progress(
-            SpinnerColumn(style="rgb(106,0,255)"),
-            TextColumn("[progress.description]{task.description}"),
-            transient=False,
-        ) as progress:
-            task_id = progress.add_task(
-                f"Pulling [rgb(106,0,255)]'{alias}'[/rgb(106,0,255)] from Judgment...",
-                total=100,
-            )
-            request_body = {
-                "alias": alias,
-                "judgment_api_key": self.judgment_api_key
-            }
-
-            try:
-                response = requests.post(
-                    JUDGMENT_DATASETS_PULL_API_URL,
-                    json=request_body
-                )
-                response.raise_for_status()
-            except requests.exceptions.RequestException as e:
-                error(f"Error pulling dataset: {str(e)}")
-                raise
-
-            info(f"Successfully pulled dataset with alias '{alias}'")
-            payload = response.json()
-            self.ground_truths = [GroundTruthExample(**g) for g in payload.get("ground_truths", [])]
-            self.examples = [Example(**e) for e in payload.get("examples", [])]
-            self._alias = payload.get("_alias")
-            self._id = payload.get("_id")
-            progress.update(
-                task_id,
-                description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
-            )

     def add_from_json(self, file_path: str) -> None:
         debug(f"Loading dataset from JSON file: {file_path}")
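The push/pull logic removed from EvalDataset corresponds to the new eval_dataset_client.py (+193 lines) listed at the top, which exports EvalDatasetClient from judgeval.data.datasets. That file's contents are not shown in this section, so the sketch below is purely hypothetical; the constructor and method signatures are assumptions:

    from judgeval.data.datasets import EvalDataset, EvalDatasetClient

    client = EvalDatasetClient(judgment_api_key="<your-api-key>")  # assumed constructor
    dataset = EvalDataset(judgment_api_key="<your-api-key>")

    client.push(dataset, alias="my-dataset", overwrite=False)      # assumed signature
    pulled = client.pull(alias="my-dataset")                       # assumed signature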
@@ -402,6 +283,4 @@ class EvalDataset:
             f"_alias={self._alias}, "
             f"_id={self._id}"
             f")"
-        )
-
-
+        )