judgeval 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/s3_storage.py +3 -1
- judgeval/common/tracer.py +1079 -139
- judgeval/common/utils.py +6 -2
- judgeval/constants.py +5 -0
- judgeval/data/datasets/dataset.py +12 -6
- judgeval/data/datasets/eval_dataset_client.py +3 -1
- judgeval/data/trace.py +7 -2
- judgeval/integrations/langgraph.py +218 -34
- judgeval/judgment_client.py +9 -1
- judgeval/rules.py +60 -50
- judgeval/run_evaluation.py +53 -29
- judgeval/scorers/judgeval_scorer.py +4 -1
- judgeval/scorers/prompt_scorer.py +3 -0
- judgeval/utils/alerts.py +8 -0
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/METADATA +48 -50
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/RECORD +18 -18
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/WHEEL +0 -0
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/utils.py
CHANGED
@@ -12,9 +12,10 @@ NOTE: any function beginning with 'a', e.g. 'afetch_together_api_response', is a
|
|
12
12
|
import asyncio
|
13
13
|
import concurrent.futures
|
14
14
|
import os
|
15
|
+
from types import TracebackType
|
15
16
|
import requests
|
16
17
|
import pprint
|
17
|
-
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
18
|
+
from typing import Any, Dict, List, Literal, Mapping, Optional, TypeAlias, Union
|
18
19
|
|
19
20
|
# Third-party imports
|
20
21
|
import litellm
|
@@ -102,7 +103,7 @@ def validate_api_key(judgment_api_key: str):
|
|
102
103
|
Validates that the user api key is valid
|
103
104
|
"""
|
104
105
|
response = requests.post(
|
105
|
-
f"{ROOT_API}/validate_api_key/",
|
106
|
+
f"{ROOT_API}/auth/validate_api_key/",
|
106
107
|
headers={
|
107
108
|
"Content-Type": "application/json",
|
108
109
|
"Authorization": f"Bearer {judgment_api_key}",
|
@@ -782,3 +783,6 @@ if __name__ == "__main__":
|
|
782
783
|
]
|
783
784
|
]
|
784
785
|
))
|
786
|
+
|
787
|
+
ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
|
788
|
+
OptExcInfo: TypeAlias = ExcInfo | tuple[None, None, None]
|
judgeval/constants.py
CHANGED
@@ -58,8 +58,13 @@ JUDGMENT_PROJECT_DELETE_API_URL = f"{ROOT_API}/projects/delete/"
|
|
58
58
|
JUDGMENT_PROJECT_CREATE_API_URL = f"{ROOT_API}/projects/add/"
|
59
59
|
JUDGMENT_TRACES_FETCH_API_URL = f"{ROOT_API}/traces/fetch/"
|
60
60
|
JUDGMENT_TRACES_SAVE_API_URL = f"{ROOT_API}/traces/save/"
|
61
|
+
JUDGMENT_TRACES_UPSERT_API_URL = f"{ROOT_API}/traces/upsert/"
|
62
|
+
JUDGMENT_TRACES_USAGE_CHECK_API_URL = f"{ROOT_API}/traces/usage/check/"
|
63
|
+
JUDGMENT_TRACES_USAGE_UPDATE_API_URL = f"{ROOT_API}/traces/usage/update/"
|
61
64
|
JUDGMENT_TRACES_DELETE_API_URL = f"{ROOT_API}/traces/delete/"
|
62
65
|
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL = f"{ROOT_API}/traces/add_annotation/"
|
66
|
+
JUDGMENT_TRACES_SPANS_BATCH_API_URL = f"{ROOT_API}/traces/spans/batch/"
|
67
|
+
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL = f"{ROOT_API}/traces/evaluation_runs/batch/"
|
63
68
|
JUDGMENT_ADD_TO_RUN_EVAL_QUEUE_API_URL = f"{ROOT_API}/add_to_run_eval_queue/"
|
64
69
|
JUDGMENT_GET_EVAL_STATUS_API_URL = f"{ROOT_API}/get_evaluation_status/"
|
65
70
|
# RabbitMQ
|
judgeval/data/datasets/dataset.py
CHANGED
@@ -5,14 +5,15 @@ import json
|
|
5
5
|
import os
|
6
6
|
import yaml
|
7
7
|
from dataclasses import dataclass, field
|
8
|
-
from typing import List, Union, Literal
|
8
|
+
from typing import List, Union, Literal, Optional
|
9
9
|
|
10
|
-
from judgeval.data import Example
|
10
|
+
from judgeval.data import Example, Trace
|
11
11
|
from judgeval.common.logger import debug, error, warning, info
|
12
12
|
|
13
13
|
@dataclass
|
14
14
|
class EvalDataset:
|
15
15
|
examples: List[Example]
|
16
|
+
traces: List[Trace]
|
16
17
|
_alias: Union[str, None] = field(default=None)
|
17
18
|
_id: Union[str, None] = field(default=None)
|
18
19
|
judgment_api_key: str = field(default="")
|
@@ -20,12 +21,13 @@ class EvalDataset:
|
|
20
21
|
def __init__(self,
|
21
22
|
judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"),
|
22
23
|
organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
|
23
|
-
examples: List[Example] =
|
24
|
+
examples: Optional[List[Example]] = None,
|
25
|
+
traces: Optional[List[Trace]] = None
|
24
26
|
):
|
25
|
-
debug(f"Initializing EvalDataset with {len(examples)} examples")
|
26
27
|
if not judgment_api_key:
|
27
28
|
warning("No judgment_api_key provided")
|
28
|
-
self.examples = examples
|
29
|
+
self.examples = examples or []
|
30
|
+
self.traces = traces or []
|
29
31
|
self._alias = None
|
30
32
|
self._id = None
|
31
33
|
self.judgment_api_key = judgment_api_key
|
@@ -218,8 +220,11 @@ class EvalDataset:
|
|
218
220
|
self.add_example(e)
|
219
221
|
|
220
222
|
def add_example(self, e: Example) -> None:
|
221
|
-
self.examples
|
223
|
+
self.examples.append(e)
|
222
224
|
# TODO if we need to add rank, then we need to do it here
|
225
|
+
|
226
|
+
def add_trace(self, t: Trace) -> None:
|
227
|
+
self.traces.append(t)
|
223
228
|
|
224
229
|
def save_as(self, file_type: Literal["json", "csv", "yaml"], dir_path: str, save_name: str = None) -> None:
|
225
230
|
"""
|
@@ -307,6 +312,7 @@ class EvalDataset:
|
|
307
312
|
return (
|
308
313
|
f"{self.__class__.__name__}("
|
309
314
|
f"examples={self.examples}, "
|
315
|
+
f"traces={self.traces}, "
|
310
316
|
f"_alias={self._alias}, "
|
311
317
|
f"_id={self._id}"
|
312
318
|
f")"
|
judgeval/data/datasets/eval_dataset_client.py
CHANGED
@@ -13,7 +13,7 @@ from judgeval.constants import (
|
|
13
13
|
JUDGMENT_DATASETS_INSERT_API_URL,
|
14
14
|
JUDGMENT_DATASETS_EXPORT_JSONL_API_URL
|
15
15
|
)
|
16
|
-
from judgeval.data import Example
|
16
|
+
from judgeval.data import Example, Trace
|
17
17
|
from judgeval.data.datasets import EvalDataset
|
18
18
|
|
19
19
|
|
@@ -58,6 +58,7 @@ class EvalDatasetClient:
|
|
58
58
|
"dataset_alias": alias,
|
59
59
|
"project_name": project_name,
|
60
60
|
"examples": [e.to_dict() for e in dataset.examples],
|
61
|
+
"traces": [t.model_dump() for t in dataset.traces],
|
61
62
|
"overwrite": overwrite,
|
62
63
|
}
|
63
64
|
try:
|
@@ -202,6 +203,7 @@ class EvalDatasetClient:
|
|
202
203
|
info(f"Successfully pulled dataset with alias '{alias}'")
|
203
204
|
payload = response.json()
|
204
205
|
dataset.examples = [Example(**e) for e in payload.get("examples", [])]
|
206
|
+
dataset.traces = [Trace(**t) for t in payload.get("traces", [])]
|
205
207
|
dataset._alias = payload.get("alias")
|
206
208
|
dataset._id = payload.get("id")
|
207
209
|
progress.update(
|
judgeval/data/trace.py
CHANGED
@@ -33,6 +33,8 @@ class TraceSpan(BaseModel):
|
|
33
33
|
additional_metadata: Optional[Dict[str, Any]] = None
|
34
34
|
has_evaluation: Optional[bool] = False
|
35
35
|
agent_name: Optional[str] = None
|
36
|
+
state_before: Optional[Dict[str, Any]] = None
|
37
|
+
state_after: Optional[Dict[str, Any]] = None
|
36
38
|
|
37
39
|
def model_dump(self, **kwargs):
|
38
40
|
return {
|
@@ -50,7 +52,10 @@ class TraceSpan(BaseModel):
|
|
50
52
|
"span_type": self.span_type,
|
51
53
|
"usage": self.usage.model_dump() if self.usage else None,
|
52
54
|
"has_evaluation": self.has_evaluation,
|
53
|
-
"agent_name": self.agent_name
|
55
|
+
"agent_name": self.agent_name,
|
56
|
+
"state_before": self.state_before,
|
57
|
+
"state_after": self.state_after,
|
58
|
+
"additional_metadata": self._serialize_value(self.additional_metadata)
|
54
59
|
}
|
55
60
|
|
56
61
|
def print_span(self):
|
@@ -113,7 +118,7 @@ class Trace(BaseModel):
|
|
113
118
|
name: str
|
114
119
|
created_at: str
|
115
120
|
duration: float
|
116
|
-
|
121
|
+
trace_spans: List[TraceSpan]
|
117
122
|
overwrite: bool = False
|
118
123
|
offline_mode: bool = False
|
119
124
|
rules: Optional[Dict[str, Any]] = None
|
judgeval/integrations/langgraph.py
CHANGED
@@ -3,9 +3,11 @@ from uuid import UUID
|
|
3
3
|
import time
|
4
4
|
import uuid
|
5
5
|
import contextvars # <--- Import contextvars
|
6
|
+
from datetime import datetime
|
6
7
|
|
7
|
-
from judgeval.common.tracer import TraceClient, TraceSpan, Tracer, SpanType, EvaluationConfig
|
8
|
+
from judgeval.common.tracer import TraceClient, TraceSpan, Tracer, SpanType, EvaluationConfig, cost_per_token
|
8
9
|
from judgeval.data import Example # Import Example
|
10
|
+
from judgeval.data.trace import TraceUsage
|
9
11
|
|
10
12
|
from langchain_core.callbacks import BaseCallbackHandler
|
11
13
|
from langchain_core.agents import AgentAction, AgentFinish
|
@@ -36,18 +38,48 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
36
38
|
def __init__(self, tracer: Tracer):
|
37
39
|
|
38
40
|
self.tracer = tracer
|
41
|
+
# Initialize tracking/logging variables (preserved across resets)
|
42
|
+
self.executed_nodes: List[str] = []
|
43
|
+
self.executed_tools: List[str] = []
|
44
|
+
self.executed_node_tools: List[str] = []
|
45
|
+
self.traces: List[Dict[str, Any]] = []
|
46
|
+
# Initialize execution state (reset between runs)
|
47
|
+
self._reset_state()
|
48
|
+
# --- END NEW __init__ ---
|
49
|
+
|
50
|
+
def _reset_state(self):
|
51
|
+
"""Reset only the critical execution state for reuse across multiple executions"""
|
52
|
+
# Reset core execution state that must be cleared between runs
|
39
53
|
self._trace_client: Optional[TraceClient] = None
|
40
54
|
self._run_id_to_span_id: Dict[UUID, str] = {}
|
41
55
|
self._span_id_to_start_time: Dict[str, float] = {}
|
42
56
|
self._span_id_to_depth: Dict[str, int] = {}
|
43
57
|
self._root_run_id: Optional[UUID] = None
|
44
|
-
self._trace_saved: bool = False
|
45
|
-
|
46
|
-
self.
|
58
|
+
self._trace_saved: bool = False
|
59
|
+
self.span_id_to_token: Dict[str, Any] = {}
|
60
|
+
self.trace_id_to_token: Dict[str, Any] = {}
|
61
|
+
|
62
|
+
# Add timestamp to track when we last reset
|
63
|
+
self._last_reset_time: float = time.time()
|
64
|
+
|
65
|
+
# Preserve tracking/logging variables across executions:
|
66
|
+
# - self.executed_nodes: List[str] = [] # Keep as running log
|
67
|
+
# - self.executed_tools: List[str] = [] # Keep as running log
|
68
|
+
# - self.executed_node_tools: List[str] = [] # Keep as running log
|
69
|
+
# - self.traces: List[Dict[str, Any]] = [] # Keep for collecting multiple traces
|
70
|
+
|
71
|
+
def reset(self):
|
72
|
+
"""Public method to manually reset handler execution state for reuse"""
|
73
|
+
self._reset_state()
|
74
|
+
|
75
|
+
def reset_all(self):
|
76
|
+
"""Public method to reset ALL handler state including tracking/logging data"""
|
77
|
+
self._reset_state()
|
78
|
+
# Also reset tracking/logging variables
|
79
|
+
self.executed_nodes: List[str] = []
|
47
80
|
self.executed_tools: List[str] = []
|
48
81
|
self.executed_node_tools: List[str] = []
|
49
82
|
self.traces: List[Dict[str, Any]] = []
|
50
|
-
# --- END NEW __init__ ---
|
51
83
|
|
52
84
|
# --- MODIFIED _ensure_trace_client ---
|
53
85
|
def _ensure_trace_client(self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str) -> Optional[TraceClient]:
|
@@ -57,6 +89,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
57
89
|
Returns the client or None.
|
58
90
|
"""
|
59
91
|
|
92
|
+
# If this is a potential new root execution (no parent_run_id) and we had a previous trace saved,
|
93
|
+
# reset state to allow reuse of the handler
|
94
|
+
if parent_run_id is None and self._trace_saved:
|
95
|
+
self._reset_state()
|
96
|
+
|
60
97
|
# If a client already exists, return it.
|
61
98
|
if self._trace_client:
|
62
99
|
return self._trace_client
|
@@ -73,11 +110,25 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
73
110
|
enable_evaluations=self.tracer.enable_evaluations
|
74
111
|
)
|
75
112
|
self._trace_client = client_instance
|
113
|
+
token = self.tracer.set_current_trace(self._trace_client)
|
114
|
+
if token:
|
115
|
+
self.trace_id_to_token[trace_id] = token
|
76
116
|
if self._trace_client:
|
77
117
|
self._root_run_id = run_id # Assign the first run_id encountered as the tentative root
|
78
118
|
self._trace_saved = False # Ensure flag is reset
|
79
119
|
# Set active client on Tracer (important for potential fallbacks)
|
80
120
|
self.tracer._active_trace_client = self._trace_client
|
121
|
+
|
122
|
+
# NEW: Initial save for live tracking (follows the new practice)
|
123
|
+
try:
|
124
|
+
trace_id_saved, server_response = self._trace_client.save_with_rate_limiting(
|
125
|
+
overwrite=self._trace_client.overwrite,
|
126
|
+
final_save=False # Initial save for live tracking
|
127
|
+
)
|
128
|
+
except Exception as e:
|
129
|
+
import warnings
|
130
|
+
warnings.warn(f"Failed to save initial trace for live tracking: {e}")
|
131
|
+
|
81
132
|
return self._trace_client
|
82
133
|
else:
|
83
134
|
return None
|
@@ -112,12 +163,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
112
163
|
self._span_id_to_start_time[span_id] = start_time
|
113
164
|
self._span_id_to_depth[span_id] = current_depth
|
114
165
|
|
115
|
-
|
116
|
-
# --- Set SPAN context variable ONLY for chain (node) spans (Sync version) ---
|
117
|
-
if span_type == "chain":
|
118
|
-
self.tracer.set_current_span(span_id)
|
119
|
-
|
120
|
-
new_trace = TraceSpan(
|
166
|
+
new_span = TraceSpan(
|
121
167
|
span_id=span_id,
|
122
168
|
trace_id=trace_client.trace_id,
|
123
169
|
parent_span_id=parent_span_id,
|
@@ -127,9 +173,36 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
127
173
|
span_type=span_type
|
128
174
|
)
|
129
175
|
|
130
|
-
|
131
|
-
|
132
|
-
|
176
|
+
# Separate metadata from inputs
|
177
|
+
if inputs:
|
178
|
+
metadata = {}
|
179
|
+
clean_inputs = {}
|
180
|
+
|
181
|
+
# Extract metadata fields
|
182
|
+
metadata_fields = ['tags', 'metadata', 'kwargs', 'serialized']
|
183
|
+
for field in metadata_fields:
|
184
|
+
if field in inputs:
|
185
|
+
metadata[field] = inputs.pop(field)
|
186
|
+
|
187
|
+
# Store the remaining inputs
|
188
|
+
clean_inputs = inputs
|
189
|
+
|
190
|
+
# Set both fields on the span
|
191
|
+
new_span.inputs = clean_inputs
|
192
|
+
new_span.additional_metadata = metadata
|
193
|
+
else:
|
194
|
+
new_span.inputs = {}
|
195
|
+
new_span.additional_metadata = {}
|
196
|
+
|
197
|
+
trace_client.add_span(new_span)
|
198
|
+
|
199
|
+
# Queue span with initial state (input phase) through background service
|
200
|
+
if trace_client.background_span_service:
|
201
|
+
trace_client.background_span_service.queue_span(new_span, span_state="input")
|
202
|
+
|
203
|
+
token = self.tracer.set_current_span(span_id)
|
204
|
+
if token:
|
205
|
+
self.span_id_to_token[span_id] = token
|
133
206
|
|
134
207
|
def _end_span_tracking(
|
135
208
|
self,
|
@@ -142,6 +215,8 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
142
215
|
|
143
216
|
# Get span ID and check if it exists
|
144
217
|
span_id = self._run_id_to_span_id.get(run_id)
|
218
|
+
token = self.span_id_to_token.pop(span_id, None)
|
219
|
+
self.tracer.reset_current_span(token, span_id)
|
145
220
|
|
146
221
|
start_time = self._span_id_to_start_time.get(span_id) if span_id else None
|
147
222
|
duration = time.time() - start_time if start_time is not None else None
|
@@ -151,7 +226,38 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
151
226
|
trace_span = trace_client.span_id_to_span.get(span_id)
|
152
227
|
if trace_span:
|
153
228
|
trace_span.duration = duration
|
154
|
-
|
229
|
+
|
230
|
+
# Handle outputs and error
|
231
|
+
if error:
|
232
|
+
trace_span.output = error
|
233
|
+
elif outputs:
|
234
|
+
# Separate metadata from outputs
|
235
|
+
metadata = {}
|
236
|
+
clean_outputs = {}
|
237
|
+
|
238
|
+
# Extract metadata fields
|
239
|
+
metadata_fields = ['tags', 'kwargs']
|
240
|
+
if isinstance(outputs, dict):
|
241
|
+
for field in metadata_fields:
|
242
|
+
if field in outputs:
|
243
|
+
metadata[field] = outputs.pop(field)
|
244
|
+
|
245
|
+
# Store the remaining outputs
|
246
|
+
clean_outputs = outputs
|
247
|
+
else:
|
248
|
+
clean_outputs = outputs
|
249
|
+
|
250
|
+
# Set both fields on the span
|
251
|
+
trace_span.output = clean_outputs
|
252
|
+
if metadata:
|
253
|
+
# Merge with existing metadata
|
254
|
+
existing_metadata = trace_span.additional_metadata or {}
|
255
|
+
trace_span.additional_metadata = {**existing_metadata, **metadata}
|
256
|
+
|
257
|
+
# Queue span with completed state through background service
|
258
|
+
if trace_client.background_span_service:
|
259
|
+
span_state = "error" if error else "completed"
|
260
|
+
trace_client.background_span_service.queue_span(trace_span, span_state=span_state)
|
155
261
|
|
156
262
|
# Clean up dictionaries for this specific span
|
157
263
|
if span_id in self._span_id_to_start_time: del self._span_id_to_start_time[span_id]
|
@@ -165,9 +271,30 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
165
271
|
# Reset input storage for this handler instance
|
166
272
|
|
167
273
|
if self._trace_client and not self._trace_saved: # Check if not already saved
|
168
|
-
#
|
169
|
-
|
170
|
-
|
274
|
+
# Flush background spans before saving the final trace
|
275
|
+
|
276
|
+
complete_trace_data = {
|
277
|
+
"trace_id": self._trace_client.trace_id,
|
278
|
+
"name": self._trace_client.name,
|
279
|
+
"created_at": datetime.utcfromtimestamp(self._trace_client.start_time).isoformat(),
|
280
|
+
"duration": self._trace_client.get_duration(),
|
281
|
+
"trace_spans": [span.model_dump() for span in self._trace_client.trace_spans],
|
282
|
+
"overwrite": self._trace_client.overwrite,
|
283
|
+
"offline_mode": self.tracer.offline_mode,
|
284
|
+
"parent_trace_id": self._trace_client.parent_trace_id,
|
285
|
+
"parent_name": self._trace_client.parent_name
|
286
|
+
}
|
287
|
+
|
288
|
+
# NEW: Use save_with_rate_limiting with final_save=True for final save
|
289
|
+
trace_id, trace_data = self._trace_client.save_with_rate_limiting(
|
290
|
+
overwrite=self._trace_client.overwrite,
|
291
|
+
final_save=True # Final save with usage counter updates
|
292
|
+
)
|
293
|
+
token = self.trace_id_to_token.pop(trace_id, None)
|
294
|
+
self.tracer.reset_current_trace(token, trace_id)
|
295
|
+
|
296
|
+
# Store complete trace data instead of server response
|
297
|
+
self.tracer.traces.append(complete_trace_data)
|
171
298
|
self._trace_saved = True # Set flag only after successful save
|
172
299
|
finally:
|
173
300
|
# --- NEW: Consolidated Cleanup Logic ---
|
@@ -254,10 +381,26 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
254
381
|
# --- Root node cleanup (Existing logic - slightly modified save call) ---
|
255
382
|
if run_id == self._root_run_id:
|
256
383
|
if trace_client and not self._trace_saved:
|
257
|
-
#
|
258
|
-
|
259
|
-
|
260
|
-
|
384
|
+
# Store complete trace data instead of server response
|
385
|
+
complete_trace_data = {
|
386
|
+
"trace_id": trace_client.trace_id,
|
387
|
+
"name": trace_client.name,
|
388
|
+
"created_at": datetime.utcfromtimestamp(trace_client.start_time).isoformat(),
|
389
|
+
"duration": trace_client.get_duration(),
|
390
|
+
"trace_spans": [span.model_dump() for span in trace_client.trace_spans],
|
391
|
+
"overwrite": trace_client.overwrite,
|
392
|
+
"offline_mode": self.tracer.offline_mode,
|
393
|
+
"parent_trace_id": trace_client.parent_trace_id,
|
394
|
+
"parent_name": trace_client.parent_name
|
395
|
+
}
|
396
|
+
# NEW: Use save_with_rate_limiting with final_save=True for final save
|
397
|
+
trace_id_saved, trace_data = trace_client.save_with_rate_limiting(
|
398
|
+
overwrite=trace_client.overwrite,
|
399
|
+
final_save=True # Final save with usage counter updates
|
400
|
+
)
|
401
|
+
|
402
|
+
|
403
|
+
self.tracer.traces.append(complete_trace_data)
|
261
404
|
self._trace_saved = True
|
262
405
|
# Reset tracer's active client *after* successful save
|
263
406
|
if self.tracer._active_trace_client == trace_client:
|
@@ -333,11 +476,23 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
333
476
|
if not trace_client:
|
334
477
|
return
|
335
478
|
outputs = {"response": response, "kwargs": kwargs}
|
336
|
-
|
337
|
-
|
338
|
-
prompt_tokens = None
|
339
|
-
completion_tokens = None
|
479
|
+
|
480
|
+
# --- Token Usage Extraction and Cost Calculation ---
|
481
|
+
prompt_tokens = None
|
482
|
+
completion_tokens = None
|
340
483
|
total_tokens = None
|
484
|
+
model_name = None
|
485
|
+
|
486
|
+
# Extract model name from response if available
|
487
|
+
if hasattr(response, 'llm_output') and response.llm_output and isinstance(response.llm_output, dict):
|
488
|
+
model_name = response.llm_output.get('model_name') or response.llm_output.get('model')
|
489
|
+
|
490
|
+
# Try to get model from the first generation if available
|
491
|
+
if not model_name and response.generations and len(response.generations) > 0:
|
492
|
+
if hasattr(response.generations[0][0], 'generation_info') and response.generations[0][0].generation_info:
|
493
|
+
gen_info = response.generations[0][0].generation_info
|
494
|
+
model_name = gen_info.get('model') or gen_info.get('model_name')
|
495
|
+
|
341
496
|
if response.llm_output and isinstance(response.llm_output, dict):
|
342
497
|
# Check for OpenAI/standard 'token_usage' first
|
343
498
|
if 'token_usage' in response.llm_output:
|
@@ -356,14 +511,43 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
356
511
|
if prompt_tokens is not None and completion_tokens is not None:
|
357
512
|
total_tokens = prompt_tokens + completion_tokens
|
358
513
|
|
359
|
-
# ---
|
514
|
+
# --- Create TraceUsage object and set on span ---
|
360
515
|
if prompt_tokens is not None or completion_tokens is not None:
|
361
|
-
#
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
516
|
+
# Calculate costs if model name is available
|
517
|
+
prompt_cost = None
|
518
|
+
completion_cost = None
|
519
|
+
total_cost_usd = None
|
520
|
+
|
521
|
+
if model_name and prompt_tokens is not None and completion_tokens is not None:
|
522
|
+
try:
|
523
|
+
prompt_cost, completion_cost = cost_per_token(
|
524
|
+
model=model_name,
|
525
|
+
prompt_tokens=prompt_tokens,
|
526
|
+
completion_tokens=completion_tokens
|
527
|
+
)
|
528
|
+
total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
|
529
|
+
except Exception as e:
|
530
|
+
# If cost calculation fails, continue without costs
|
531
|
+
import warnings
|
532
|
+
warnings.warn(f"Failed to calculate token costs for model {model_name}: {e}")
|
533
|
+
|
534
|
+
# Create TraceUsage object
|
535
|
+
usage = TraceUsage(
|
536
|
+
prompt_tokens=prompt_tokens,
|
537
|
+
completion_tokens=completion_tokens,
|
538
|
+
total_tokens=total_tokens or (prompt_tokens + completion_tokens if prompt_tokens and completion_tokens else None),
|
539
|
+
prompt_tokens_cost_usd=prompt_cost,
|
540
|
+
completion_tokens_cost_usd=completion_cost,
|
541
|
+
total_cost_usd=total_cost_usd,
|
542
|
+
model_name=model_name
|
543
|
+
)
|
544
|
+
|
545
|
+
# Set usage on the actual span (not in outputs)
|
546
|
+
span_id = self._run_id_to_span_id.get(run_id)
|
547
|
+
if span_id and span_id in trace_client.span_id_to_span:
|
548
|
+
trace_span = trace_client.span_id_to_span[span_id]
|
549
|
+
trace_span.usage = usage
|
550
|
+
|
367
551
|
|
368
552
|
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
369
553
|
# --- End Token Usage ---
|
@@ -416,4 +600,4 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
416
600
|
if not trace_client: return
|
417
601
|
|
418
602
|
outputs = {'return_values': finish.return_values, 'log': finish.log, 'messages': finish.messages, 'kwargs': kwargs}
|
419
|
-
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
603
|
+
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
judgeval/judgment_client.py
CHANGED
@@ -63,7 +63,15 @@ class SingletonMeta(type):
|
|
63
63
|
return cls._instances[cls]
|
64
64
|
|
65
65
|
class JudgmentClient(metaclass=SingletonMeta):
|
66
|
-
def __init__(self, judgment_api_key: str = os.getenv("JUDGMENT_API_KEY"), organization_id: str = os.getenv("JUDGMENT_ORG_ID")):
|
66
|
+
def __init__(self, judgment_api_key: Optional[str] = os.getenv("JUDGMENT_API_KEY"), organization_id: Optional[str] = os.getenv("JUDGMENT_ORG_ID")):
|
67
|
+
# Check if API key is None
|
68
|
+
if judgment_api_key is None:
|
69
|
+
raise ValueError("JUDGMENT_API_KEY cannot be None. Please provide a valid API key or set the JUDGMENT_API_KEY environment variable.")
|
70
|
+
|
71
|
+
# Check if organization ID is None
|
72
|
+
if organization_id is None:
|
73
|
+
raise ValueError("JUDGMENT_ORG_ID cannot be None. Please provide a valid organization ID or set the JUDGMENT_ORG_ID environment variable.")
|
74
|
+
|
67
75
|
self.judgment_api_key = judgment_api_key
|
68
76
|
self.organization_id = organization_id
|
69
77
|
self.eval_dataset_client = EvalDatasetClient(judgment_api_key, organization_id)
|