judgeval 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/utils.py CHANGED
@@ -12,9 +12,10 @@ NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is a
12
12
  import asyncio
13
13
  import concurrent.futures
14
14
  import os
15
+ from types import TracebackType
15
16
  import requests
16
17
  import pprint
17
- from typing import Any, Dict, List, Literal, Mapping, Optional, Union
18
+ from typing import Any, Dict, List, Literal, Mapping, Optional, TypeAlias, Union
18
19
 
19
20
  # Third-party imports
20
21
  import litellm
@@ -102,7 +103,7 @@ def validate_api_key(judgment_api_key: str):
102
103
  Validates that the user api key is valid
103
104
  """
104
105
  response = requests.post(
105
- f"{ROOT_API}/validate_api_key/",
106
+ f"{ROOT_API}/auth/validate_api_key/",
106
107
  headers={
107
108
  "Content-Type": "application/json",
108
109
  "Authorization": f"Bearer {judgment_api_key}",
@@ -782,3 +783,6 @@ if __name__ == "__main__":
782
783
  ]
783
784
  ]
784
785
  ))
786
+
787
+ ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
788
+ OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]
judgeval/constants.py CHANGED
@@ -58,8 +58,13 @@ JUDGMENT_PROJECT_DELETE_API_URL = f"{ROOT_API}/projects/delete/"
58
58
  JUDGMENT_PROJECT_CREATE_API_URL = f"{ROOT_API}/projects/add/"
59
59
  JUDGMENT_TRACES_FETCH_API_URL = f"{ROOT_API}/traces/fetch/"
60
60
  JUDGMENT_TRACES_SAVE_API_URL = f"{ROOT_API}/traces/save/"
61
+ JUDGMENT_TRACES_UPSERT_API_URL = f"{ROOT_API}/traces/upsert/"
62
+ JUDGMENT_TRACES_USAGE_CHECK_API_URL = f"{ROOT_API}/traces/usage/check/"
63
+ JUDGMENT_TRACES_USAGE_UPDATE_API_URL = f"{ROOT_API}/traces/usage/update/"
61
64
  JUDGMENT_TRACES_DELETE_API_URL = f"{ROOT_API}/traces/delete/"
62
65
  JUDGMENT_TRACES_ADD_ANNOTATION_API_URL = f"{ROOT_API}/traces/add_annotation/"
66
+ JUDGMENT_TRACES_SPANS_BATCH_API_URL = f"{ROOT_API}/traces/spans/batch/"
67
+ JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL = f"{ROOT_API}/traces/evaluation_runs/batch/"
63
68
  JUDGMENT_ADD_TO_RUN_EVAL_QUEUE_API_URL = f"{ROOT_API}/add_to_run_eval_queue/"
64
69
  JUDGMENT_GET_EVAL_STATUS_API_URL = f"{ROOT_API}/get_evaluation_status/"
65
70
  # RabbitMQ
@@ -5,14 +5,15 @@ import json
5
5
  import os
6
6
  import yaml
7
7
  from dataclasses import dataclass, field
8
- from typing import List, Union, Literal
8
+ from typing import List, Union, Literal, Optional
9
9
 
10
- from judgeval.data import Example
10
+ from judgeval.data import Example, Trace
11
11
  from judgeval.common.logger import debug, error, warning, info
12
12
 
13
13
  @dataclass
14
14
  class EvalDataset:
15
15
  examples: List[Example]
16
+ traces: List[Trace]
16
17
  _alias: Union[str, None] = field(default=None)
17
18
  _id: Union[str, None] = field(default=None)
18
19
  judgment_api_key: str = field(default="")
@@ -20,12 +21,13 @@ class EvalDataset:
20
21
  def __init__(self,
21
22
  judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"),
22
23
  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
23
- examples: List[Example] = [],
24
+ examples: Optional[List[Example]] = None,
25
+ traces: Optional[List[Trace]] = None
24
26
  ):
25
- debug(f"Initializing EvalDataset with {len(examples)} examples")
26
27
  if not judgment_api_key:
27
28
  warning("No judgment_api_key provided")
28
- self.examples = examples
29
+ self.examples = examples or []
30
+ self.traces = traces or []
29
31
  self._alias = None
30
32
  self._id = None
31
33
  self.judgment_api_key = judgment_api_key
@@ -218,8 +220,11 @@ class EvalDataset:
218
220
  self.add_example(e)
219
221
 
220
222
  def add_example(self, e: Example) -> None:
221
- self.examples = self.examples + [e]
223
+ self.examples.append(e)
222
224
  # TODO if we need to add rank, then we need to do it here
225
+
226
+ def add_trace(self, t: Trace) -> None:
227
+ self.traces.append(t)
223
228
 
224
229
  def save_as(self, file_type: Literal["json", "csv", "yaml"], dir_path: str, save_name: str = None) -> None:
225
230
  """
@@ -307,6 +312,7 @@ class EvalDataset:
307
312
  return (
308
313
  f"{self.__class__.__name__}("
309
314
  f"examples={self.examples}, "
315
+ f"traces={self.traces}, "
310
316
  f"_alias={self._alias}, "
311
317
  f"_id={self._id}"
312
318
  f")"
@@ -13,7 +13,7 @@ from judgeval.constants import (
13
13
  JUDGMENT_DATASETS_INSERT_API_URL,
14
14
  JUDGMENT_DATASETS_EXPORT_JSONL_API_URL
15
15
  )
16
- from judgeval.data import Example
16
+ from judgeval.data import Example, Trace
17
17
  from judgeval.data.datasets import EvalDataset
18
18
 
19
19
 
@@ -58,6 +58,7 @@ class EvalDatasetClient:
58
58
  "dataset_alias": alias,
59
59
  "project_name": project_name,
60
60
  "examples": [e.to_dict() for e in dataset.examples],
61
+ "traces": [t.model_dump() for t in dataset.traces],
61
62
  "overwrite": overwrite,
62
63
  }
63
64
  try:
@@ -202,6 +203,7 @@ class EvalDatasetClient:
202
203
  info(f"Successfully pulled dataset with alias '{alias}'")
203
204
  payload = response.json()
204
205
  dataset.examples = [Example(**e) for e in payload.get("examples", [])]
206
+ dataset.traces = [Trace(**t) for t in payload.get("traces", [])]
205
207
  dataset._alias = payload.get("alias")
206
208
  dataset._id = payload.get("id")
207
209
  progress.update(
judgeval/data/trace.py CHANGED
@@ -33,6 +33,8 @@ class TraceSpan(BaseModel):
33
33
  additional_metadata: Optional[Dict[str, Any]] = None
34
34
  has_evaluation: Optional[bool] = False
35
35
  agent_name: Optional[str] = None
36
+ state_before: Optional[Dict[str, Any]] = None
37
+ state_after: Optional[Dict[str, Any]] = None
36
38
 
37
39
  def model_dump(self, **kwargs):
38
40
  return {
@@ -50,7 +52,10 @@ class TraceSpan(BaseModel):
50
52
  "span_type": self.span_type,
51
53
  "usage": self.usage.model_dump() if self.usage else None,
52
54
  "has_evaluation": self.has_evaluation,
53
- "agent_name": self.agent_name
55
+ "agent_name": self.agent_name,
56
+ "state_before": self.state_before,
57
+ "state_after": self.state_after,
58
+ "additional_metadata": self._serialize_value(self.additional_metadata)
54
59
  }
55
60
 
56
61
  def print_span(self):
@@ -113,7 +118,7 @@ class Trace(BaseModel):
113
118
  name: str
114
119
  created_at: str
115
120
  duration: float
116
- entries: List[TraceSpan]
121
+ trace_spans: List[TraceSpan]
117
122
  overwrite: bool = False
118
123
  offline_mode: bool = False
119
124
  rules: Optional[Dict[str, Any]] = None
@@ -3,9 +3,11 @@ from uuid import UUID
3
3
  import time
4
4
  import uuid
5
5
  import contextvars # <--- Import contextvars
6
+ from datetime import datetime
6
7
 
7
- from judgeval.common.tracer import TraceClient, TraceSpan, Tracer, SpanType, EvaluationConfig
8
+ from judgeval.common.tracer import TraceClient, TraceSpan, Tracer, SpanType, EvaluationConfig, cost_per_token
8
9
  from judgeval.data import Example # Import Example
10
+ from judgeval.data.trace import TraceUsage
9
11
 
10
12
  from langchain_core.callbacks import BaseCallbackHandler
11
13
  from langchain_core.agents import AgentAction, AgentFinish
@@ -36,18 +38,48 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
36
38
  def __init__(self, tracer: Tracer):
37
39
 
38
40
  self.tracer = tracer
41
+ # Initialize tracking/logging variables (preserved across resets)
42
+ self.executed_nodes: List[str] = []
43
+ self.executed_tools: List[str] = []
44
+ self.executed_node_tools: List[str] = []
45
+ self.traces: List[Dict[str, Any]] = []
46
+ # Initialize execution state (reset between runs)
47
+ self._reset_state()
48
+ # --- END NEW __init__ ---
49
+
50
+ def _reset_state(self):
51
+ """Reset only the critical execution state for reuse across multiple executions"""
52
+ # Reset core execution state that must be cleared between runs
39
53
  self._trace_client: Optional[TraceClient] = None
40
54
  self._run_id_to_span_id: Dict[UUID, str] = {}
41
55
  self._span_id_to_start_time: Dict[str, float] = {}
42
56
  self._span_id_to_depth: Dict[str, int] = {}
43
57
  self._root_run_id: Optional[UUID] = None
44
- self._trace_saved: bool = False # Flag to prevent actions after trace is saved
45
-
46
- self.executed_nodes: List[str] = [] # These last four members are only appended to and never accessed; can probably be removed but still might be useful for future reference?
58
+ self._trace_saved: bool = False
59
+ self.span_id_to_token: Dict[str, Any] = {}
60
+ self.trace_id_to_token: Dict[str, Any] = {}
61
+
62
+ # Add timestamp to track when we last reset
63
+ self._last_reset_time: float = time.time()
64
+
65
+ # Preserve tracking/logging variables across executions:
66
+ # - self.executed_nodes: List[str] = [] # Keep as running log
67
+ # - self.executed_tools: List[str] = [] # Keep as running log
68
+ # - self.executed_node_tools: List[str] = [] # Keep as running log
69
+ # - self.traces: List[Dict[str, Any]] = [] # Keep for collecting multiple traces
70
+
71
+ def reset(self):
72
+ """Public method to manually reset handler execution state for reuse"""
73
+ self._reset_state()
74
+
75
+ def reset_all(self):
76
+ """Public method to reset ALL handler state including tracking/logging data"""
77
+ self._reset_state()
78
+ # Also reset tracking/logging variables
79
+ self.executed_nodes: List[str] = []
47
80
  self.executed_tools: List[str] = []
48
81
  self.executed_node_tools: List[str] = []
49
82
  self.traces: List[Dict[str, Any]] = []
50
- # --- END NEW __init__ ---
51
83
 
52
84
  # --- MODIFIED _ensure_trace_client ---
53
85
  def _ensure_trace_client(self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str) -> Optional[TraceClient]:
@@ -57,6 +89,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
57
89
  Returns the client or None.
58
90
  """
59
91
 
92
+ # If this is a potential new root execution (no parent_run_id) and we had a previous trace saved,
93
+ # reset state to allow reuse of the handler
94
+ if parent_run_id is None and self._trace_saved:
95
+ self._reset_state()
96
+
60
97
  # If a client already exists, return it.
61
98
  if self._trace_client:
62
99
  return self._trace_client
@@ -73,11 +110,25 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
73
110
  enable_evaluations=self.tracer.enable_evaluations
74
111
  )
75
112
  self._trace_client = client_instance
113
+ token = self.tracer.set_current_trace(self._trace_client)
114
+ if token:
115
+ self.trace_id_to_token[trace_id] = token
76
116
  if self._trace_client:
77
117
  self._root_run_id = run_id # Assign the first run_id encountered as the tentative root
78
118
  self._trace_saved = False # Ensure flag is reset
79
119
  # Set active client on Tracer (important for potential fallbacks)
80
120
  self.tracer._active_trace_client = self._trace_client
121
+
122
+ # NEW: Initial save for live tracking (follows the new practice)
123
+ try:
124
+ trace_id_saved, server_response = self._trace_client.save_with_rate_limiting(
125
+ overwrite=self._trace_client.overwrite,
126
+ final_save=False # Initial save for live tracking
127
+ )
128
+ except Exception as e:
129
+ import warnings
130
+ warnings.warn(f"Failed to save initial trace for live tracking: {e}")
131
+
81
132
  return self._trace_client
82
133
  else:
83
134
  return None
@@ -112,12 +163,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
112
163
  self._span_id_to_start_time[span_id] = start_time
113
164
  self._span_id_to_depth[span_id] = current_depth
114
165
 
115
-
116
- # --- Set SPAN context variable ONLY for chain (node) spans (Sync version) ---
117
- if span_type == "chain":
118
- self.tracer.set_current_span(span_id)
119
-
120
- new_trace = TraceSpan(
166
+ new_span = TraceSpan(
121
167
  span_id=span_id,
122
168
  trace_id=trace_client.trace_id,
123
169
  parent_span_id=parent_span_id,
@@ -127,9 +173,36 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
127
173
  span_type=span_type
128
174
  )
129
175
 
130
- new_trace.inputs = inputs
131
-
132
- trace_client.add_span(new_trace)
176
+ # Separate metadata from inputs
177
+ if inputs:
178
+ metadata = {}
179
+ clean_inputs = {}
180
+
181
+ # Extract metadata fields
182
+ metadata_fields = ['tags', 'metadata', 'kwargs', 'serialized']
183
+ for field in metadata_fields:
184
+ if field in inputs:
185
+ metadata[field] = inputs.pop(field)
186
+
187
+ # Store the remaining inputs
188
+ clean_inputs = inputs
189
+
190
+ # Set both fields on the span
191
+ new_span.inputs = clean_inputs
192
+ new_span.additional_metadata = metadata
193
+ else:
194
+ new_span.inputs = {}
195
+ new_span.additional_metadata = {}
196
+
197
+ trace_client.add_span(new_span)
198
+
199
+ # Queue span with initial state (input phase) through background service
200
+ if trace_client.background_span_service:
201
+ trace_client.background_span_service.queue_span(new_span, span_state="input")
202
+
203
+ token = self.tracer.set_current_span(span_id)
204
+ if token:
205
+ self.span_id_to_token[span_id] = token
133
206
 
134
207
  def _end_span_tracking(
135
208
  self,
@@ -142,6 +215,8 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
142
215
 
143
216
  # Get span ID and check if it exists
144
217
  span_id = self._run_id_to_span_id.get(run_id)
218
+ token = self.span_id_to_token.pop(span_id, None)
219
+ self.tracer.reset_current_span(token, span_id)
145
220
 
146
221
  start_time = self._span_id_to_start_time.get(span_id) if span_id else None
147
222
  duration = time.time() - start_time if start_time is not None else None
@@ -151,7 +226,38 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
151
226
  trace_span = trace_client.span_id_to_span.get(span_id)
152
227
  if trace_span:
153
228
  trace_span.duration = duration
154
- trace_span.output = error if error else outputs
229
+
230
+ # Handle outputs and error
231
+ if error:
232
+ trace_span.output = error
233
+ elif outputs:
234
+ # Separate metadata from outputs
235
+ metadata = {}
236
+ clean_outputs = {}
237
+
238
+ # Extract metadata fields
239
+ metadata_fields = ['tags', 'kwargs']
240
+ if isinstance(outputs, dict):
241
+ for field in metadata_fields:
242
+ if field in outputs:
243
+ metadata[field] = outputs.pop(field)
244
+
245
+ # Store the remaining outputs
246
+ clean_outputs = outputs
247
+ else:
248
+ clean_outputs = outputs
249
+
250
+ # Set both fields on the span
251
+ trace_span.output = clean_outputs
252
+ if metadata:
253
+ # Merge with existing metadata
254
+ existing_metadata = trace_span.additional_metadata or {}
255
+ trace_span.additional_metadata = {**existing_metadata, **metadata}
256
+
257
+ # Queue span with completed state through background service
258
+ if trace_client.background_span_service:
259
+ span_state = "error" if error else "completed"
260
+ trace_client.background_span_service.queue_span(trace_span, span_state=span_state)
155
261
 
156
262
  # Clean up dictionaries for this specific span
157
263
  if span_id in self._span_id_to_start_time: del self._span_id_to_start_time[span_id]
@@ -165,9 +271,30 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
165
271
  # Reset input storage for this handler instance
166
272
 
167
273
  if self._trace_client and not self._trace_saved: # Check if not already saved
168
- # TODO: Check if trace_client.save needs await if TraceClient becomes async
169
- trace_id, trace_data = self._trace_client.save(overwrite=self._trace_client.overwrite) # Use client's overwrite setting
170
- self.traces.append(trace_data) # Leaving this in for now but can probably be removed
274
+ # Flush background spans before saving the final trace
275
+
276
+ complete_trace_data = {
277
+ "trace_id": self._trace_client.trace_id,
278
+ "name": self._trace_client.name,
279
+ "created_at": datetime.utcfromtimestamp(self._trace_client.start_time).isoformat(),
280
+ "duration": self._trace_client.get_duration(),
281
+ "trace_spans": [span.model_dump() for span in self._trace_client.trace_spans],
282
+ "overwrite": self._trace_client.overwrite,
283
+ "offline_mode": self.tracer.offline_mode,
284
+ "parent_trace_id": self._trace_client.parent_trace_id,
285
+ "parent_name": self._trace_client.parent_name
286
+ }
287
+
288
+ # NEW: Use save_with_rate_limiting with final_save=True for final save
289
+ trace_id, trace_data = self._trace_client.save_with_rate_limiting(
290
+ overwrite=self._trace_client.overwrite,
291
+ final_save=True # Final save with usage counter updates
292
+ )
293
+ token = self.trace_id_to_token.pop(trace_id, None)
294
+ self.tracer.reset_current_trace(token, trace_id)
295
+
296
+ # Store complete trace data instead of server response
297
+ self.tracer.traces.append(complete_trace_data)
171
298
  self._trace_saved = True # Set flag only after successful save
172
299
  finally:
173
300
  # --- NEW: Consolidated Cleanup Logic ---
@@ -254,10 +381,26 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
254
381
  # --- Root node cleanup (Existing logic - slightly modified save call) ---
255
382
  if run_id == self._root_run_id:
256
383
  if trace_client and not self._trace_saved:
257
- # Save might need to be async if TraceClient methods become async
258
- # Pass overwrite=True based on client's setting
259
- trace_id_saved, trace_data = trace_client.save(overwrite=trace_client.overwrite)
260
- self.traces.append(trace_data) # Leaving this in for now but can probably be removed
384
+ # Store complete trace data instead of server response
385
+ complete_trace_data = {
386
+ "trace_id": trace_client.trace_id,
387
+ "name": trace_client.name,
388
+ "created_at": datetime.utcfromtimestamp(trace_client.start_time).isoformat(),
389
+ "duration": trace_client.get_duration(),
390
+ "trace_spans": [span.model_dump() for span in trace_client.trace_spans],
391
+ "overwrite": trace_client.overwrite,
392
+ "offline_mode": self.tracer.offline_mode,
393
+ "parent_trace_id": trace_client.parent_trace_id,
394
+ "parent_name": trace_client.parent_name
395
+ }
396
+ # NEW: Use save_with_rate_limiting with final_save=True for final save
397
+ trace_id_saved, trace_data = trace_client.save_with_rate_limiting(
398
+ overwrite=trace_client.overwrite,
399
+ final_save=True # Final save with usage counter updates
400
+ )
401
+
402
+
403
+ self.tracer.traces.append(complete_trace_data)
261
404
  self._trace_saved = True
262
405
  # Reset tracer's active client *after* successful save
263
406
  if self.tracer._active_trace_client == trace_client:
@@ -333,11 +476,23 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
333
476
  if not trace_client:
334
477
  return
335
478
  outputs = {"response": response, "kwargs": kwargs}
336
- # --- Token Usage Extraction and Accumulation ---
337
- token_usage = None
338
- prompt_tokens = None # Use standard name
339
- completion_tokens = None # Use standard name
479
+
480
+ # --- Token Usage Extraction and Cost Calculation ---
481
+ prompt_tokens = None
482
+ completion_tokens = None
340
483
  total_tokens = None
484
+ model_name = None
485
+
486
+ # Extract model name from response if available
487
+ if hasattr(response, 'llm_output') and response.llm_output and isinstance(response.llm_output, dict):
488
+ model_name = response.llm_output.get('model_name') or response.llm_output.get('model')
489
+
490
+ # Try to get model from the first generation if available
491
+ if not model_name and response.generations and len(response.generations) > 0:
492
+ if hasattr(response.generations[0][0], 'generation_info') and response.generations[0][0].generation_info:
493
+ gen_info = response.generations[0][0].generation_info
494
+ model_name = gen_info.get('model') or gen_info.get('model_name')
495
+
341
496
  if response.llm_output and isinstance(response.llm_output, dict):
342
497
  # Check for OpenAI/standard 'token_usage' first
343
498
  if 'token_usage' in response.llm_output:
@@ -356,14 +511,43 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
356
511
  if prompt_tokens is not None and completion_tokens is not None:
357
512
  total_tokens = prompt_tokens + completion_tokens
358
513
 
359
- # --- Store individual usage in span output and Accumulate ---
514
+ # --- Create TraceUsage object and set on span ---
360
515
  if prompt_tokens is not None or completion_tokens is not None:
361
- # Store individual usage for this span
362
- outputs['usage'] = {
363
- 'prompt_tokens': prompt_tokens,
364
- 'completion_tokens': completion_tokens,
365
- 'total_tokens': total_tokens
366
- }
516
+ # Calculate costs if model name is available
517
+ prompt_cost = None
518
+ completion_cost = None
519
+ total_cost_usd = None
520
+
521
+ if model_name and prompt_tokens is not None and completion_tokens is not None:
522
+ try:
523
+ prompt_cost, completion_cost = cost_per_token(
524
+ model=model_name,
525
+ prompt_tokens=prompt_tokens,
526
+ completion_tokens=completion_tokens
527
+ )
528
+ total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
529
+ except Exception as e:
530
+ # If cost calculation fails, continue without costs
531
+ import warnings
532
+ warnings.warn(f"Failed to calculate token costs for model {model_name}: {e}")
533
+
534
+ # Create TraceUsage object
535
+ usage = TraceUsage(
536
+ prompt_tokens=prompt_tokens,
537
+ completion_tokens=completion_tokens,
538
+ total_tokens=total_tokens or (prompt_tokens + completion_tokens if prompt_tokens and completion_tokens else None),
539
+ prompt_tokens_cost_usd=prompt_cost,
540
+ completion_tokens_cost_usd=completion_cost,
541
+ total_cost_usd=total_cost_usd,
542
+ model_name=model_name
543
+ )
544
+
545
+ # Set usage on the actual span (not in outputs)
546
+ span_id = self._run_id_to_span_id.get(run_id)
547
+ if span_id and span_id in trace_client.span_id_to_span:
548
+ trace_span = trace_client.span_id_to_span[span_id]
549
+ trace_span.usage = usage
550
+
367
551
 
368
552
  self._end_span_tracking(trace_client, run_id, outputs=outputs)
369
553
  # --- End Token Usage ---
@@ -416,4 +600,4 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
416
600
  if not trace_client: return
417
601
 
418
602
  outputs = {'return_values': finish.return_values, 'log': finish.log, 'messages': finish.messages, 'kwargs': kwargs}
419
- self._end_span_tracking(trace_client, run_id, outputs=outputs)
603
+ self._end_span_tracking(trace_client, run_id, outputs=outputs)
@@ -63,7 +63,15 @@ class SingletonMeta(type):
63
63
  return cls._instances[cls]
64
64
 
65
65
  class JudgmentClient(metaclass=SingletonMeta):
66
- def __init__(self, judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"), organization_id: str = os.getenv("JUDGMENT_ORG_ID")):
66
+ def __init__(self, judgment_api_key: Optional[str] = os.getenv("JUDGMENT_API_KEY"), organization_id: Optional[str] = os.getenv("JUDGMENT_ORG_ID")):
67
+ # Check if API key is None
68
+ if judgment_api_key is None:
69
+ raise ValueError("JUDGMENT_API_KEY cannot be None. Please provide a valid API key or set the JUDGMENT_API_KEY environment variable.")
70
+
71
+ # Check if organization ID is None
72
+ if organization_id is None:
73
+ raise ValueError("JUDGMENT_ORG_ID cannot be None. Please provide a valid organization ID or set the JUDGMENT_ORG_ID environment variable.")
74
+
67
75
  self.judgment_api_key = judgment_api_key
68
76
  self.organization_id = organization_id
69
77
  self.eval_dataset_client = EvalDatasetClient(judgment_api_key, organization_id)