judgeval 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -2,13 +2,26 @@
 Tracing system for judgeval that allows for function tracing using decorators.
 """
 
+import os
 import time
 import functools
 import requests
 import uuid
 from contextlib import contextmanager
-from typing import Optional, Any, List, Literal, Tuple, Generator, TypeAlias, Union
-from dataclasses import dataclass, field
+from typing import (
+    Optional,
+    Any,
+    List,
+    Literal,
+    Tuple,
+    Generator,
+    TypeAlias,
+    Union
+)
+from dataclasses import (
+    dataclass,
+    field
+)
 from datetime import datetime
 from openai import OpenAI
 from together import Together
@@ -21,16 +34,25 @@ import warnings
 from pydantic import BaseModel
 from http import HTTPStatus
 
-from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL
+import pika
+import os
+
+from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL, JUDGMENT_TRACES_FETCH_API_URL, RABBITMQ_HOST, RABBITMQ_PORT, RABBITMQ_QUEUE, JUDGMENT_TRACES_DELETE_API_URL
 from judgeval.judgment_client import JudgmentClient
 from judgeval.data import Example
-from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
+from judgeval.scorers import APIJudgmentScorer, JudgevalScorer, ScorerWrapper
+
+from rich import print as rprint
+
 from judgeval.data.result import ScoringResult
+from judgeval.evaluation_run import EvaluationRun
 
 # Define type aliases for better code readability and maintainability
 ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
 TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
 SpanType = Literal['span', 'tool', 'llm', 'evaluation']
+
+
 @dataclass
 class TraceEntry:
     """Represents a single trace entry with its visual representation.
@@ -52,7 +74,7 @@ class TraceEntry:
     # Use field() for mutable defaults to avoid shared state issues
     inputs: dict = field(default_factory=dict)
     span_type: SpanType = "span"
-    evaluation_result: Optional[List[ScoringResult]] = field(default=None)
+    evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
 
     def print_entry(self):
        indent = " " * self.depth
@@ -65,7 +87,8 @@ class TraceEntry:
         elif self.type == "input":
             print(f"{indent}Input: {self.inputs}")
         elif self.type == "evaluation":
-            print(f"{indent}Evaluation: {self.evaluation_result} ({self.duration:.3f}s)")
+            for evaluation_run in self.evaluation_runs:
+                print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
 
     def _serialize_inputs(self) -> dict:
         """Helper method to serialize input data safely.
@@ -112,7 +135,7 @@ class TraceEntry:
             "duration": self.duration,
             "output": self._serialize_output(),
             "inputs": self._serialize_inputs(),
-            "evaluation_result": [result.to_dict() for result in self.evaluation_result] if self.evaluation_result else None,
+            "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
             "span_type": self.span_type
         }
 
@@ -121,8 +144,29 @@ class TraceEntry:
 
         Handles special cases:
         - Pydantic models are converted using model_dump()
+        - We try to serialize into JSON, then string, then the base representation (__repr__)
         - Non-serializable objects return None with a warning
         """
+
+        def safe_stringify(output, function_name):
+            """
+            Safely converts an object to a string or repr, handling serialization issues gracefully.
+            """
+            try:
+                return str(output)
+            except (TypeError, OverflowError, ValueError):
+                pass
+
+            try:
+                return repr(output)
+            except (TypeError, OverflowError, ValueError):
+                pass
+
+            warnings.warn(
+                f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
+            )
+            return None
+
         if isinstance(self.output, BaseModel):
             return self.output.model_dump()
 
@@ -131,8 +175,107 @@ class TraceEntry:
             json.dumps(self.output)
             return self.output
         except (TypeError, OverflowError, ValueError):
-            warnings.warn(f"Output for function {self.function} is not JSON serializable. Setting to None.")
-            return None
+            return safe_stringify(self.output, self.function)
+
+
+class TraceManagerClient:
+    """
+    Client for handling trace endpoints with the Judgment API
+
+
+    Operations include:
+    - Fetching a trace by id
+    - Saving a trace
+    - Deleting a trace
+    """
+    def __init__(self, judgment_api_key: str):
+        self.judgment_api_key = judgment_api_key
+
+    def fetch_trace(self, trace_id: str):
+        """
+        Fetch a trace by its id
+        """
+        response = requests.post(
+            JUDGMENT_TRACES_FETCH_API_URL,
+            json={
+                "trace_id": trace_id,
+                "judgment_api_key": self.judgment_api_key,
+            },
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to fetch traces: {response.text}")
+
+        return response.json()
+
+    def save_trace(self, trace_data: dict, empty_save: bool):
+        """
+        Saves a trace to the database
+
+        Args:
+            trace_data: The trace data to save
+            empty_save: Whether to save an empty trace
+        NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
+        """
+        response = requests.post(
+            JUDGMENT_TRACES_SAVE_API_URL,
+            json=trace_data,
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code == HTTPStatus.BAD_REQUEST:
+            raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
+        elif response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to save trace data: {response.text}")
+
+        if not empty_save and "ui_results_url" in response.json():
+            rprint(f"\n🔍 You can view your trace data here: [rgb(106,0,255)]{response.json()['ui_results_url']}[/]\n")
+
+    def delete_trace(self, trace_id: str):
+        """
+        Delete a trace from the database.
+        """
+        response = requests.delete(
+            JUDGMENT_TRACES_DELETE_API_URL,
+            json={
+                "judgment_api_key": self.judgment_api_key,
+                "trace_ids": [trace_id],
+            },
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to delete trace: {response.text}")
+
+        return response.json()
+
+    def delete_traces(self, trace_ids: List[str]):
+        """
+        Delete a batch of traces from the database.
+        """
+        response = requests.delete(
+            JUDGMENT_TRACES_DELETE_API_URL,
+            json={
+                "judgment_api_key": self.judgment_api_key,
+                "trace_ids": trace_ids,
+            },
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to delete trace: {response.text}")
+
+        return response.json()
+
 
 class TraceClient:
     """Client for managing a single trace context"""
@@ -147,6 +290,7 @@ class TraceClient:
         self.span_type = None
         self._current_span: Optional[TraceEntry] = None
         self.overwrite = overwrite
+        self.trace_manager_client = TraceManagerClient(tracer.api_key) # Manages DB operations for trace data
 
     @contextmanager
     def span(self, name: str, span_type: SpanType = "span"):
@@ -163,6 +307,7 @@ class TraceClient:
             span_type=span_type
         ))
 
+        # Increment nested depth and set current span
         self.tracer.depth += 1
         prev_span = self._current_span
         self._current_span = name
@@ -185,7 +330,7 @@ class TraceClient:
         ))
         self._current_span = prev_span
 
-    async def async_evaluate(
+    def async_evaluate(
         self,
         scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
         input: Optional[str] = None,
@@ -211,25 +356,40 @@ class TraceClient:
             additional_metadata=additional_metadata,
             trace_id=self.trace_id
         )
-        scoring_results = self.client.run_evaluation(
-            examples=[example],
-            scorers=scorers,
-            model=model,
-            metadata={},
+
+        try:
+            # Load appropriate implementations for all scorers
+            loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = [
+                scorer.load_implementation(use_judgment=True) if isinstance(scorer, ScorerWrapper) else scorer
+                for scorer in scorers
+            ]
+        except Exception as e:
+            raise ValueError(f"Failed to load scorers: {str(e)}")
+
+        eval_run = EvaluationRun(
             log_results=log_results,
             project_name=self.project_name,
-            eval_run_name=(
-                f"{self.name.capitalize()}-"
+            eval_name=f"{self.name.capitalize()}-"
                 f"{self._current_span}-"
-                f"[{','.join(scorer.load_implementation().score_type.capitalize() for scorer in scorers)}]"
-            ),
+                f"[{','.join(scorer.load_implementation().score_type.capitalize() for scorer in scorers)}]",
+            examples=[example],
+            scorers=loaded_scorers,
+            model=model,
+            metadata={},
+            judgment_api_key=self.tracer.api_key,
             override=self.overwrite
         )
 
-        self.record_evaluation(scoring_results, start_time) # Pass start_time to record_evaluation
+        self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
 
-    def record_evaluation(self, results: List[ScoringResult], start_time: float):
-        """Record evaluation results for the current span"""
+    def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
+        """
+        Add evaluation run data to the trace
+
+        Args:
+            eval_run (EvaluationRun): The evaluation run to add to the trace
+            start_time (float): The start time of the evaluation run
+        """
         if self._current_span:
             duration = time.time() - start_time # Calculate duration from start_time
 
@@ -239,7 +399,7 @@ class TraceClient:
                 depth=self.tracer.depth,
                 message=f"Evaluation results for {self._current_span}",
                 timestamp=time.time(),
-                evaluation_result=results,
+                evaluation_runs=[eval_run],
                 duration=duration,
                 span_type="evaluation"
             ))
@@ -320,7 +480,7 @@ class TraceClient:
                     "timestamp": entry["timestamp"],
                     "inputs": None,
                     "output": None,
-                    "evaluation_result": None,
+                    "evaluation_runs": [],
                     "span_type": entry.get("span_type", "span")
                 }
                 active_functions.append(function)
@@ -343,8 +503,8 @@ class TraceClient:
             if entry["type"] == "output" and entry["output"]:
                 current_entry["output"] = entry["output"]
 
-            if entry["type"] == "evaluation" and entry["evaluation_result"]:
-                current_entry["evaluation_result"] = entry["evaluation_result"]
+            if entry["type"] == "evaluation" and entry["evaluation_runs"]:
+                current_entry["evaluation_runs"] = entry["evaluation_runs"]
 
         # Sort by timestamp
         condensed.sort(key=lambda x: x["timestamp"])
@@ -361,6 +521,24 @@ class TraceClient:
         raw_entries = [entry.to_dict() for entry in self.entries]
         condensed_entries = self.condense_trace(raw_entries)
 
+        # Calculate total token counts from LLM API calls
+        total_prompt_tokens = 0
+        total_completion_tokens = 0
+        total_tokens = 0
+
+        for entry in condensed_entries:
+            if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
+                usage = entry["output"].get("usage", {})
+                # Handle OpenAI/Together format
+                if "prompt_tokens" in usage:
+                    total_prompt_tokens += usage.get("prompt_tokens", 0)
+                    total_completion_tokens += usage.get("completion_tokens", 0)
+                # Handle Anthropic format
+                elif "input_tokens" in usage:
+                    total_prompt_tokens += usage.get("input_tokens", 0)
+                    total_completion_tokens += usage.get("output_tokens", 0)
+                total_tokens += usage.get("total_tokens", 0)
+
         # Create trace document
         trace_data = {
             "trace_id": self.trace_id,
@@ -370,31 +548,38 @@ class TraceClient:
             "created_at": datetime.fromtimestamp(self.start_time).isoformat(),
             "duration": total_duration,
             "token_counts": {
-                "prompt_tokens": 0, # Dummy value
-                "completion_tokens": 0, # Dummy value
-                "total_tokens": 0, # Dummy value
-            }, # TODO: Add token counts
+                "prompt_tokens": total_prompt_tokens,
+                "completion_tokens": total_completion_tokens,
+                "total_tokens": total_tokens,
+            },
             "entries": condensed_entries,
             "empty_save": empty_save,
             "overwrite": overwrite
         }
-
-        # Save trace data by making POST request to API
-        response = requests.post(
-            JUDGMENT_TRACES_SAVE_API_URL,
-            json=trace_data,
-            headers={
-                "Content-Type": "application/json",
-            }
-        )
 
-        if response.status_code == HTTPStatus.BAD_REQUEST:
-            raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
-        elif response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to save trace data: {response.text}")
+        if not empty_save:
+            connection = pika.BlockingConnection(
+                pika.ConnectionParameters(host=RABBITMQ_HOST, port=RABBITMQ_PORT))
+            channel = connection.channel()
+
+            channel.queue_declare(queue=RABBITMQ_QUEUE, durable=True)
+
+            channel.basic_publish(
+                exchange='',
+                routing_key=RABBITMQ_QUEUE,
+                body=json.dumps(trace_data),
+                properties=pika.BasicProperties(
+                    delivery_mode=pika.DeliveryMode.Transient # Changed from Persistent to Transient
+                ))
+            connection.close()
 
+        self.trace_manager_client.save_trace(trace_data, empty_save)
+
         return self.trace_id, trace_data
 
+    def delete(self):
+        return self.trace_manager_client.delete_trace(self.trace_id)
+
 class Tracer:
     _instance = None
 
@@ -403,7 +588,7 @@ class Tracer:
             cls._instance = super(Tracer, cls).__new__(cls)
         return cls._instance
 
-    def __init__(self, api_key: str):
+    def __init__(self, api_key: str = os.getenv("JUDGMENT_API_KEY")):
         if not hasattr(self, 'initialized'):
 
             if not api_key:
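
For orientation, here is a minimal usage sketch of the trace-management surface added in this release. It is not taken from the package's documentation: only the constructor defaults and method signatures visible in the diff above are used, and the trace id is a placeholder.

    # Hypothetical usage sketch based solely on the 0.0.11 tracer.py diff above.
    import os
    from judgeval.common.tracer import Tracer, TraceManagerClient

    # 0.0.11 lets the API key default to the JUDGMENT_API_KEY environment variable.
    tracer = Tracer()  # equivalent to Tracer(api_key=os.getenv("JUDGMENT_API_KEY"))

    client = TraceManagerClient(judgment_api_key=os.environ["JUDGMENT_API_KEY"])
    trace = client.fetch_trace(trace_id="some-trace-id")   # POST to /traces/fetch/
    client.delete_traces(trace_ids=["some-trace-id"])      # DELETE to /traces/delete/
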
judgeval/constants.py CHANGED
@@ -32,15 +32,25 @@ class APIScorer(str, Enum):
         return member
 
 ROOT_API = os.getenv("JUDGMENT_API_URL", "https://api.judgmentlabs.ai")
-## API URLs
+# API URLs
 JUDGMENT_EVAL_API_URL = f"{ROOT_API}/evaluate/"
 JUDGMENT_DATASETS_PUSH_API_URL = f"{ROOT_API}/datasets/push/"
 JUDGMENT_DATASETS_PULL_API_URL = f"{ROOT_API}/datasets/pull/"
+JUDGMENT_DATASETS_PULL_ALL_API_URL = f"{ROOT_API}/datasets/get_all_stats/"
 JUDGMENT_EVAL_LOG_API_URL = f"{ROOT_API}/log_eval_results/"
 JUDGMENT_EVAL_FETCH_API_URL = f"{ROOT_API}/fetch_eval_results/"
+JUDGMENT_EVAL_DELETE_API_URL = f"{ROOT_API}/delete_eval_results_by_project_and_run_name/"
+JUDGMENT_EVAL_DELETE_PROJECT_API_URL = f"{ROOT_API}/delete_eval_results_by_project/"
+JUDGMENT_TRACES_FETCH_API_URL = f"{ROOT_API}/traces/fetch/"
 JUDGMENT_TRACES_SAVE_API_URL = f"{ROOT_API}/traces/save/"
+JUDGMENT_TRACES_DELETE_API_URL = f"{ROOT_API}/traces/delete/"
 
-## Models
+# RabbitMQ
+RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", "rabbitmq-networklb-faa155df16ec9085.elb.us-west-1.amazonaws.com")
+RABBITMQ_PORT = os.getenv("RABBITMQ_PORT", 5672)
+RABBITMQ_QUEUE = os.getenv("RABBITMQ_QUEUE", "task_queue")
+
+# Models
 TOGETHER_SUPPORTED_MODELS = {
     "QWEN": "Qwen/Qwen2-72B-Instruct",
     "LLAMA3_70B_INSTRUCT_TURBO": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
@@ -50,7 +60,9 @@ TOGETHER_SUPPORTED_MODELS = {
     "MISTRAL_8x7B_INSTRUCT": "mistralai/Mixtral-8x7B-Instruct-v0.1",
 }
 
-ACCEPTABLE_MODELS = set(litellm.model_list) | set(TOGETHER_SUPPORTED_MODELS.keys())
+JUDGMENT_SUPPORTED_MODELS = {"osiris-large", "osiris-mini"}
+
+ACCEPTABLE_MODELS = set(litellm.model_list) | set(TOGETHER_SUPPORTED_MODELS.keys()) | JUDGMENT_SUPPORTED_MODELS
 
 ## System settings
 MAX_WORKER_THREADS = 10
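
The new connection settings above are all read with os.getenv, so they can be overridden from the environment. A small sketch, assuming the variables are set before judgeval.constants is first imported; the hostnames and values are illustrative only.

    # Sketch: overriding the endpoint and RabbitMQ settings via environment variables.
    import os

    os.environ["JUDGMENT_API_URL"] = "https://api.example.internal"  # overrides ROOT_API
    os.environ["RABBITMQ_HOST"] = "localhost"
    os.environ["RABBITMQ_PORT"] = "5672"
    os.environ["RABBITMQ_QUEUE"] = "task_queue"

    from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL, RABBITMQ_HOST

    print(JUDGMENT_TRACES_SAVE_API_URL)  # https://api.example.internal/traces/save/
    print(RABBITMQ_HOST)                 # localhost
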
judgeval/data/datasets/__init__.py CHANGED
@@ -1,4 +1,5 @@
 from judgeval.data.datasets.dataset import EvalDataset
 from judgeval.data.datasets.ground_truth import GroundTruthExample
+from judgeval.data.datasets.eval_dataset_client import EvalDatasetClient
 
-__all__ = ["EvalDataset", "GroundTruthExample"]
+__all__ = ["EvalDataset", "EvalDatasetClient", "GroundTruthExample"]
judgeval/data/datasets/dataset.py CHANGED
@@ -2,16 +2,11 @@ import ast
 import csv
 import datetime
 import json
-from rich.console import Console
-from rich.progress import Progress, SpinnerColumn, TextColumn
-import requests
 from dataclasses import dataclass, field
 import os
 from typing import List, Optional, Union, Literal
 
-from judgeval.constants import JUDGMENT_DATASETS_PUSH_API_URL, JUDGMENT_DATASETS_PULL_API_URL
 from judgeval.data.datasets.ground_truth import GroundTruthExample
-from judgeval.data.datasets.utils import ground_truths_to_examples, examples_to_ground_truths
 from judgeval.data import Example
 from judgeval.common.logger import debug, error, warning, info
 
@@ -37,120 +32,6 @@ class EvalDataset:
         self._id = None
         self.judgment_api_key = judgment_api_key
 
-    def push(self, alias: str, overwrite: Optional[bool] = False) -> bool:
-        debug(f"Pushing dataset with alias '{alias}' (overwrite={overwrite})")
-        if overwrite:
-            warning(f"Overwrite enabled for alias '{alias}'")
-        """
-        Pushes the dataset to Judgment platform
-
-        Mock request:
-        {
-            "alias": alias,
-            "ground_truths": [...],
-            "examples": [...],
-            "overwrite": overwrite
-        } ==>
-        {
-            "_alias": alias,
-            "_id": "..." # ID of the dataset
-        }
-        """
-        with Progress(
-            SpinnerColumn(style="rgb(106,0,255)"),
-            TextColumn("[progress.description]{task.description}"),
-            transient=False,
-        ) as progress:
-            task_id = progress.add_task(
-                f"Pushing [rgb(106,0,255)]'{alias}' to Judgment...",
-                total=100,
-            )
-            content = {
-                "alias": alias,
-                "ground_truths": [g.to_dict() for g in self.ground_truths],
-                "examples": [e.to_dict() for e in self.examples],
-                "overwrite": overwrite,
-                "judgment_api_key": self.judgment_api_key
-            }
-            try:
-                response = requests.post(
-                    JUDGMENT_DATASETS_PUSH_API_URL,
-                    json=content
-                )
-                if response.status_code == 500:
-                    error(f"Server error during push: {content.get('message')}")
-                    return False
-                response.raise_for_status()
-            except requests.exceptions.HTTPError as err:
-                if response.status_code == 422:
-                    error(f"Validation error during push: {err.response.json()}")
-                else:
-                    error(f"HTTP error during push: {err}")
-
-            info(f"Successfully pushed dataset with alias '{alias}'")
-            payload = response.json()
-            self._alias = payload.get("_alias")
-            self._id = payload.get("_id")
-            progress.update(
-                task_id,
-                description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
-            )
-            return True
-
-    def pull(self, alias: str):
-        debug(f"Pulling dataset with alias '{alias}'")
-        """
-        Pulls the dataset from Judgment platform
-
-        Mock request:
-        {
-            "alias": alias,
-            "user_id": user_id
-        }
-        ==>
-        {
-            "ground_truths": [...],
-            "examples": [...],
-            "_alias": alias,
-            "_id": "..." # ID of the dataset
-        }
-        """
-        # Make a POST request to the Judgment API to get the dataset
-
-        with Progress(
-            SpinnerColumn(style="rgb(106,0,255)"),
-            TextColumn("[progress.description]{task.description}"),
-            transient=False,
-        ) as progress:
-            task_id = progress.add_task(
-                f"Pulling [rgb(106,0,255)]'{alias}'[/rgb(106,0,255)] from Judgment...",
-                total=100,
-            )
-            request_body = {
-                "alias": alias,
-                "judgment_api_key": self.judgment_api_key
-            }
-
-            try:
-                response = requests.post(
-                    JUDGMENT_DATASETS_PULL_API_URL,
-                    json=request_body
-                )
-                response.raise_for_status()
-            except requests.exceptions.RequestException as e:
-                error(f"Error pulling dataset: {str(e)}")
-                raise
-
-            info(f"Successfully pulled dataset with alias '{alias}'")
-            payload = response.json()
-            self.ground_truths = [GroundTruthExample(**g) for g in payload.get("ground_truths", [])]
-            self.examples = [Example(**e) for e in payload.get("examples", [])]
-            self._alias = payload.get("_alias")
-            self._id = payload.get("_id")
-            progress.update(
-                task_id,
-                description=f"{progress.tasks[task_id].description} [rgb(25,227,160)]Done!)",
-            )
 
     def add_from_json(self, file_path: str) -> None:
         debug(f"Loading dataset from JSON file: {file_path}")
@@ -402,6 +283,4 @@ class EvalDataset:
             f"_alias={self._alias}, "
             f"_id={self._id}"
             f")"
-        )
-
-
+        )
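
At the import level, the dataset changes amount to removing push() and pull() from EvalDataset and exporting the new EvalDatasetClient. A minimal sketch of what that implies for callers; the client's own methods are not shown in this diff, so none are invoked here.

    # Sketch of the 0.0.11 dataset API surface, based only on the hunks above.
    from judgeval.data.datasets import EvalDataset, EvalDatasetClient  # EvalDatasetClient newly exported

    # push()/pull() are no longer defined on EvalDataset itself; the push/pull network
    # logic presumably lives in the new eval_dataset_client module instead.
    assert "push" not in vars(EvalDataset)
    assert "pull" not in vars(EvalDataset)
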