judgeval 0.0.52__py3-none-any.whl → 0.0.54__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/logger.py +46 -199
- judgeval/common/s3_storage.py +2 -6
- judgeval/common/tracer.py +182 -262
- judgeval/common/utils.py +16 -36
- judgeval/constants.py +14 -20
- judgeval/data/__init__.py +0 -2
- judgeval/data/datasets/dataset.py +6 -10
- judgeval/data/datasets/eval_dataset_client.py +25 -27
- judgeval/data/example.py +5 -138
- judgeval/data/judgment_types.py +214 -0
- judgeval/data/result.py +7 -25
- judgeval/data/scorer_data.py +28 -40
- judgeval/data/scripts/fix_default_factory.py +23 -0
- judgeval/data/scripts/openapi_transform.py +123 -0
- judgeval/data/tool.py +3 -54
- judgeval/data/trace.py +31 -50
- judgeval/data/trace_run.py +3 -3
- judgeval/evaluation_run.py +16 -23
- judgeval/integrations/langgraph.py +11 -12
- judgeval/judges/litellm_judge.py +3 -6
- judgeval/judges/mixture_of_judges.py +8 -25
- judgeval/judges/together_judge.py +3 -6
- judgeval/judgment_client.py +22 -24
- judgeval/rules.py +7 -19
- judgeval/run_evaluation.py +79 -242
- judgeval/scorers/__init__.py +4 -20
- judgeval/scorers/agent_scorer.py +21 -0
- judgeval/scorers/api_scorer.py +28 -38
- judgeval/scorers/base_scorer.py +98 -0
- judgeval/scorers/example_scorer.py +19 -0
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -20
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +10 -17
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +9 -24
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +16 -68
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +4 -12
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +10 -17
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +18 -14
- judgeval/scorers/score.py +45 -330
- judgeval/scorers/utils.py +6 -88
- judgeval/utils/file_utils.py +4 -6
- judgeval/version_check.py +3 -2
- {judgeval-0.0.52.dist-info → judgeval-0.0.54.dist-info}/METADATA +6 -5
- judgeval-0.0.54.dist-info/RECORD +65 -0
- judgeval/data/custom_example.py +0 -19
- judgeval/scorers/judgeval_scorer.py +0 -177
- judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +0 -45
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +0 -29
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +0 -29
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +0 -32
- judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +0 -28
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +0 -38
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +0 -27
- judgeval/scorers/prompt_scorer.py +0 -296
- judgeval-0.0.52.dist-info/RECORD +0 -69
- {judgeval-0.0.52.dist-info → judgeval-0.0.54.dist-info}/WHEEL +0 -0
- {judgeval-0.0.52.dist-info → judgeval-0.0.54.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -13,7 +13,6 @@ import threading
|
|
13
13
|
import time
|
14
14
|
import traceback
|
15
15
|
import uuid
|
16
|
-
import warnings
|
17
16
|
import contextvars
|
18
17
|
import sys
|
19
18
|
import json
|
@@ -23,7 +22,7 @@ from contextlib import (
|
|
23
22
|
AbstractContextManager,
|
24
23
|
) # Import context manager bases
|
25
24
|
from dataclasses import dataclass
|
26
|
-
from datetime import datetime
|
25
|
+
from datetime import datetime, timezone
|
27
26
|
from http import HTTPStatus
|
28
27
|
from typing import (
|
29
28
|
Any,
|
@@ -53,7 +52,6 @@ from google import genai
|
|
53
52
|
# Local application/library-specific imports
|
54
53
|
from judgeval.constants import (
|
55
54
|
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
56
|
-
JUDGMENT_TRACES_SAVE_API_URL,
|
57
55
|
JUDGMENT_TRACES_UPSERT_API_URL,
|
58
56
|
JUDGMENT_TRACES_FETCH_API_URL,
|
59
57
|
JUDGMENT_TRACES_DELETE_API_URL,
|
@@ -62,11 +60,10 @@ from judgeval.constants import (
|
|
62
60
|
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
|
63
61
|
)
|
64
62
|
from judgeval.data import Example, Trace, TraceSpan, TraceUsage
|
65
|
-
from judgeval.scorers import
|
66
|
-
from judgeval.rules import Rule
|
63
|
+
from judgeval.scorers import APIScorerConfig, BaseScorer
|
67
64
|
from judgeval.evaluation_run import EvaluationRun
|
68
65
|
from judgeval.common.utils import ExcInfo, validate_api_key
|
69
|
-
from judgeval.common.
|
66
|
+
from judgeval.common.logger import judgeval_logger
|
70
67
|
|
71
68
|
# Standard library imports needed for the new class
|
72
69
|
import concurrent.futures
|
@@ -157,80 +154,6 @@ class TraceManagerClient:
|
|
157
154
|
|
158
155
|
return response.json()
|
159
156
|
|
160
|
-
def save_trace(
|
161
|
-
self, trace_data: dict, offline_mode: bool = False, final_save: bool = True
|
162
|
-
):
|
163
|
-
"""
|
164
|
-
Saves a trace to the Judgment Supabase and optionally to S3 if configured.
|
165
|
-
|
166
|
-
Args:
|
167
|
-
trace_data: The trace data to save
|
168
|
-
offline_mode: Whether running in offline mode
|
169
|
-
final_save: Whether this is the final save (controls S3 saving)
|
170
|
-
NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
|
171
|
-
|
172
|
-
Returns:
|
173
|
-
dict: Server response containing UI URL and other metadata
|
174
|
-
"""
|
175
|
-
# Save to Judgment API
|
176
|
-
|
177
|
-
def fallback_encoder(obj):
|
178
|
-
"""
|
179
|
-
Custom JSON encoder fallback.
|
180
|
-
Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
|
181
|
-
You can choose which one you prefer or try them in sequence.
|
182
|
-
"""
|
183
|
-
try:
|
184
|
-
# Option 1: Prefer __repr__ for a more detailed representation
|
185
|
-
return repr(obj)
|
186
|
-
except Exception:
|
187
|
-
# Option 2: Fallback to str() if __repr__ fails or if you prefer str()
|
188
|
-
try:
|
189
|
-
return str(obj)
|
190
|
-
except Exception as e:
|
191
|
-
# If both fail, you might return a placeholder or re-raise
|
192
|
-
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
193
|
-
|
194
|
-
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
195
|
-
response = requests.post(
|
196
|
-
JUDGMENT_TRACES_SAVE_API_URL,
|
197
|
-
data=serialized_trace_data,
|
198
|
-
headers={
|
199
|
-
"Content-Type": "application/json",
|
200
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
201
|
-
"X-Organization-Id": self.organization_id,
|
202
|
-
},
|
203
|
-
verify=True,
|
204
|
-
)
|
205
|
-
|
206
|
-
if response.status_code == HTTPStatus.BAD_REQUEST:
|
207
|
-
raise ValueError(
|
208
|
-
f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}"
|
209
|
-
)
|
210
|
-
elif response.status_code != HTTPStatus.OK:
|
211
|
-
raise ValueError(f"Failed to save trace data: {response.text}")
|
212
|
-
|
213
|
-
# Parse server response
|
214
|
-
server_response = response.json()
|
215
|
-
|
216
|
-
# If S3 storage is enabled, save to S3 only on final save
|
217
|
-
if self.tracer and self.tracer.use_s3 and final_save:
|
218
|
-
try:
|
219
|
-
s3_key = self.tracer.s3_storage.save_trace(
|
220
|
-
trace_data=trace_data,
|
221
|
-
trace_id=trace_data["trace_id"],
|
222
|
-
project_name=trace_data["project_name"],
|
223
|
-
)
|
224
|
-
print(f"Trace also saved to S3 at key: {s3_key}")
|
225
|
-
except Exception as e:
|
226
|
-
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
227
|
-
|
228
|
-
if not offline_mode and "ui_results_url" in server_response:
|
229
|
-
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
230
|
-
rprint(pretty_str)
|
231
|
-
|
232
|
-
return server_response
|
233
|
-
|
234
157
|
def upsert_trace(
|
235
158
|
self,
|
236
159
|
trace_data: dict,
|
@@ -291,9 +214,9 @@ class TraceManagerClient:
|
|
291
214
|
trace_id=trace_data["trace_id"],
|
292
215
|
project_name=trace_data["project_name"],
|
293
216
|
)
|
294
|
-
|
217
|
+
judgeval_logger.info(f"Trace also saved to S3 at key: {s3_key}")
|
295
218
|
except Exception as e:
|
296
|
-
|
219
|
+
judgeval_logger.warning(f"Failed to save trace to S3: {str(e)}")
|
297
220
|
|
298
221
|
if not offline_mode and show_link and "ui_results_url" in server_response:
|
299
222
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
@@ -401,8 +324,6 @@ class TraceClient:
|
|
401
324
|
trace_id: Optional[str] = None,
|
402
325
|
name: str = "default",
|
403
326
|
project_name: str | None = None,
|
404
|
-
overwrite: bool = False,
|
405
|
-
rules: Optional[List[Rule]] = None,
|
406
327
|
enable_monitoring: bool = True,
|
407
328
|
enable_evaluations: bool = True,
|
408
329
|
parent_trace_id: Optional[str] = None,
|
@@ -410,10 +331,8 @@ class TraceClient:
|
|
410
331
|
):
|
411
332
|
self.name = name
|
412
333
|
self.trace_id = trace_id or str(uuid.uuid4())
|
413
|
-
self.project_name = project_name or
|
414
|
-
self.overwrite = overwrite
|
334
|
+
self.project_name = project_name or "default_project"
|
415
335
|
self.tracer = tracer
|
416
|
-
self.rules = rules or []
|
417
336
|
self.enable_monitoring = enable_monitoring
|
418
337
|
self.enable_evaluations = enable_evaluations
|
419
338
|
self.parent_trace_id = parent_trace_id
|
@@ -422,6 +341,7 @@ class TraceClient:
|
|
422
341
|
self.tags: List[Union[str, set, tuple]] = [] # Added tags attribute
|
423
342
|
self.metadata: Dict[str, Any] = {}
|
424
343
|
self.has_notification: Optional[bool] = False # Initialize has_notification
|
344
|
+
self.update_id: int = 1 # Initialize update_id to 1, increments with each save
|
425
345
|
self.trace_spans: List[TraceSpan] = []
|
426
346
|
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
427
347
|
self.evaluation_runs: List[EvaluationRun] = []
|
@@ -457,15 +377,13 @@ class TraceClient:
|
|
457
377
|
is_first_span = len(self.trace_spans) == 0
|
458
378
|
if is_first_span:
|
459
379
|
try:
|
460
|
-
trace_id, server_response = self.save(
|
461
|
-
overwrite=self.overwrite, final_save=False
|
462
|
-
)
|
380
|
+
trace_id, server_response = self.save(final_save=False)
|
463
381
|
# Set start_time after first successful save
|
464
|
-
if self.start_time is None:
|
465
|
-
self.start_time = time.time()
|
466
382
|
# Link will be shown by upsert_trace method
|
467
383
|
except Exception as e:
|
468
|
-
|
384
|
+
judgeval_logger.warning(
|
385
|
+
f"Failed to save initial trace for live tracking: {e}"
|
386
|
+
)
|
469
387
|
start_time = time.time()
|
470
388
|
|
471
389
|
# Generate a unique ID for *this specific span invocation*
|
@@ -518,7 +436,7 @@ class TraceClient:
|
|
518
436
|
|
519
437
|
def async_evaluate(
|
520
438
|
self,
|
521
|
-
scorers: List[Union[
|
439
|
+
scorers: List[Union[APIScorerConfig, BaseScorer]],
|
522
440
|
example: Optional[Example] = None,
|
523
441
|
input: Optional[str] = None,
|
524
442
|
actual_output: Optional[Union[str, List[str]]] = None,
|
@@ -539,19 +457,11 @@ class TraceClient:
|
|
539
457
|
try:
|
540
458
|
# Load appropriate implementations for all scorers
|
541
459
|
if not scorers:
|
542
|
-
|
460
|
+
judgeval_logger.warning("No valid scorers available for evaluation")
|
543
461
|
return
|
544
462
|
|
545
|
-
# Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
|
546
|
-
if self.rules and any(
|
547
|
-
isinstance(scorer, JudgevalScorer) for scorer in scorers
|
548
|
-
):
|
549
|
-
raise ValueError(
|
550
|
-
"Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types."
|
551
|
-
)
|
552
|
-
|
553
463
|
except Exception as e:
|
554
|
-
|
464
|
+
judgeval_logger.warning(f"Failed to load scorers: {str(e)}")
|
555
465
|
return
|
556
466
|
|
557
467
|
# If example is not provided, create one from the individual parameters
|
@@ -589,15 +499,8 @@ class TraceClient:
|
|
589
499
|
|
590
500
|
# check_examples([example], scorers)
|
591
501
|
|
592
|
-
# --- Modification: Capture span_id immediately ---
|
593
|
-
# span_id_at_eval_call = current_span_var.get()
|
594
|
-
# print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
|
595
|
-
# Prioritize explicitly passed span_id, fallback to context var
|
596
502
|
span_id_to_use = span_id if span_id is not None else self.get_current_span()
|
597
|
-
# print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
|
598
|
-
# --- End Modification ---
|
599
503
|
|
600
|
-
# Combine the trace-level rules with any evaluation-specific rules)
|
601
504
|
eval_run = EvaluationRun(
|
602
505
|
organization_id=self.tracer.organization_id,
|
603
506
|
project_name=self.project_name,
|
@@ -608,7 +511,6 @@ class TraceClient:
|
|
608
511
|
scorers=scorers,
|
609
512
|
model=model,
|
610
513
|
judgment_api_key=self.tracer.api_key,
|
611
|
-
override=self.overwrite,
|
612
514
|
trace_span_id=span_id_to_use,
|
613
515
|
)
|
614
516
|
|
@@ -631,7 +533,6 @@ class TraceClient:
|
|
631
533
|
|
632
534
|
if current_span_id:
|
633
535
|
span = self.span_id_to_span[current_span_id]
|
634
|
-
span.evaluation_runs.append(eval_run)
|
635
536
|
span.has_evaluation = True # Set the has_evaluation flag
|
636
537
|
self.evaluation_runs.append(eval_run)
|
637
538
|
|
@@ -654,7 +555,7 @@ class TraceClient:
|
|
654
555
|
if self.background_span_service:
|
655
556
|
self.background_span_service.queue_span(span, span_state="input")
|
656
557
|
except Exception as e:
|
657
|
-
|
558
|
+
judgeval_logger.warning(f"Failed to queue span with input data: {e}")
|
658
559
|
|
659
560
|
def record_agent_name(self, agent_name: str):
|
660
561
|
current_span_id = self.get_current_span()
|
@@ -779,15 +680,12 @@ class TraceClient:
|
|
779
680
|
return 0.0 # No duration if trace hasn't been saved yet
|
780
681
|
return time.time() - self.start_time
|
781
682
|
|
782
|
-
def save(
|
783
|
-
self, overwrite: bool = False, final_save: bool = False
|
784
|
-
) -> Tuple[str, dict]:
|
683
|
+
def save(self, final_save: bool = False) -> Tuple[str, dict]:
|
785
684
|
"""
|
786
685
|
Save the current trace to the database with rate limiting checks.
|
787
686
|
First checks usage limits, then upserts the trace if allowed.
|
788
687
|
|
789
688
|
Args:
|
790
|
-
overwrite: Whether to overwrite existing traces
|
791
689
|
final_save: Whether this is the final save (updates usage counters)
|
792
690
|
|
793
691
|
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
@@ -801,17 +699,19 @@ class TraceClient:
|
|
801
699
|
"trace_id": self.trace_id,
|
802
700
|
"name": self.name,
|
803
701
|
"project_name": self.project_name,
|
804
|
-
"created_at": datetime.
|
702
|
+
"created_at": datetime.fromtimestamp(
|
703
|
+
self.start_time or time.time(), timezone.utc
|
704
|
+
).isoformat(),
|
805
705
|
"duration": total_duration,
|
806
706
|
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
807
707
|
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
808
|
-
"overwrite": overwrite,
|
809
708
|
"offline_mode": self.tracer.offline_mode,
|
810
709
|
"parent_trace_id": self.parent_trace_id,
|
811
710
|
"parent_name": self.parent_name,
|
812
711
|
"customer_id": self.customer_id,
|
813
712
|
"tags": self.tags,
|
814
713
|
"metadata": self.metadata,
|
714
|
+
"update_id": self.update_id,
|
815
715
|
}
|
816
716
|
|
817
717
|
# If usage check passes, upsert the trace
|
@@ -822,12 +722,15 @@ class TraceClient:
|
|
822
722
|
final_save=final_save, # Pass final_save to control S3 saving
|
823
723
|
)
|
824
724
|
|
725
|
+
if self.start_time is None:
|
726
|
+
self.start_time = time.time()
|
727
|
+
|
728
|
+
self.update_id += 1
|
729
|
+
|
825
730
|
# Upload annotations
|
826
731
|
# TODO: batch to the log endpoint
|
827
732
|
for annotation in self.annotations:
|
828
733
|
self.trace_manager_client.save_annotation(annotation)
|
829
|
-
if self.start_time is None:
|
830
|
-
self.start_time = time.time()
|
831
734
|
return self.trace_id, server_response
|
832
735
|
|
833
736
|
def delete(self):
|
@@ -844,7 +747,6 @@ class TraceClient:
|
|
844
747
|
- customer_id: ID of the customer using this trace
|
845
748
|
- tags: List of tags for this trace
|
846
749
|
- has_notification: Whether this trace has a notification
|
847
|
-
- overwrite: Whether to overwrite existing traces
|
848
750
|
- name: Name of the trace
|
849
751
|
"""
|
850
752
|
for k, v in metadata.items():
|
@@ -872,10 +774,6 @@ class TraceClient:
|
|
872
774
|
f"has_notification must be a boolean, got {type(v)}"
|
873
775
|
)
|
874
776
|
self.has_notification = v
|
875
|
-
elif k == "overwrite":
|
876
|
-
if not isinstance(v, bool):
|
877
|
-
raise ValueError(f"overwrite must be a boolean, got {type(v)}")
|
878
|
-
self.overwrite = v
|
879
777
|
elif k == "name":
|
880
778
|
self.name = v
|
881
779
|
else:
|
@@ -1056,7 +954,7 @@ class BackgroundSpanService:
|
|
1056
954
|
last_flush_time = current_time
|
1057
955
|
|
1058
956
|
except Exception as e:
|
1059
|
-
|
957
|
+
judgeval_logger.warning(f"Error in span service worker loop: {e}")
|
1060
958
|
# On error, still need to mark tasks as done to prevent hanging
|
1061
959
|
for _ in range(pending_task_count):
|
1062
960
|
self._span_queue.task_done()
|
@@ -1100,7 +998,7 @@ class BackgroundSpanService:
|
|
1100
998
|
self._send_evaluation_runs_batch(evaluation_runs_to_send)
|
1101
999
|
|
1102
1000
|
except Exception as e:
|
1103
|
-
|
1001
|
+
judgeval_logger.warning(f"Failed to send batch: {e}")
|
1104
1002
|
|
1105
1003
|
def _send_spans_batch(self, spans: List[Dict[str, Any]]):
|
1106
1004
|
"""Send a batch of spans to the spans endpoint."""
|
@@ -1133,14 +1031,14 @@ class BackgroundSpanService:
|
|
1133
1031
|
)
|
1134
1032
|
|
1135
1033
|
if response.status_code != HTTPStatus.OK:
|
1136
|
-
|
1034
|
+
judgeval_logger.warning(
|
1137
1035
|
f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
|
1138
1036
|
)
|
1139
1037
|
|
1140
1038
|
except RequestException as e:
|
1141
|
-
|
1039
|
+
judgeval_logger.warning(f"Network error sending spans batch: {e}")
|
1142
1040
|
except Exception as e:
|
1143
|
-
|
1041
|
+
judgeval_logger.warning(f"Failed to serialize or send spans batch: {e}")
|
1144
1042
|
|
1145
1043
|
def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
|
1146
1044
|
"""Send a batch of evaluation runs with their associated span data to the endpoint."""
|
@@ -1195,14 +1093,14 @@ class BackgroundSpanService:
|
|
1195
1093
|
)
|
1196
1094
|
|
1197
1095
|
if response.status_code != HTTPStatus.OK:
|
1198
|
-
|
1096
|
+
judgeval_logger.warning(
|
1199
1097
|
f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
|
1200
1098
|
)
|
1201
1099
|
|
1202
1100
|
except RequestException as e:
|
1203
|
-
|
1101
|
+
judgeval_logger.warning(f"Network error sending evaluation runs batch: {e}")
|
1204
1102
|
except Exception as e:
|
1205
|
-
|
1103
|
+
judgeval_logger.warning(f"Failed to send evaluation runs batch: {e}")
|
1206
1104
|
|
1207
1105
|
def queue_span(self, span: TraceSpan, span_state: str = "input"):
|
1208
1106
|
"""
|
@@ -1213,6 +1111,8 @@ class BackgroundSpanService:
|
|
1213
1111
|
span_state: State of the span ("input", "output", "completed")
|
1214
1112
|
"""
|
1215
1113
|
if not self._shutdown_event.is_set():
|
1114
|
+
# Increment update_id when queueing the span
|
1115
|
+
span.increment_update_id()
|
1216
1116
|
span_data = {
|
1217
1117
|
"type": "span",
|
1218
1118
|
"data": {
|
@@ -1252,7 +1152,7 @@ class BackgroundSpanService:
|
|
1252
1152
|
# Wait for the queue to be processed
|
1253
1153
|
self._span_queue.join()
|
1254
1154
|
except Exception as e:
|
1255
|
-
|
1155
|
+
judgeval_logger.warning(f"Error during flush: {e}")
|
1256
1156
|
|
1257
1157
|
def shutdown(self):
|
1258
1158
|
"""Shutdown the background service and flush remaining spans."""
|
@@ -1267,9 +1167,9 @@ class BackgroundSpanService:
|
|
1267
1167
|
try:
|
1268
1168
|
self.flush()
|
1269
1169
|
except Exception as e:
|
1270
|
-
|
1170
|
+
judgeval_logger.warning(f"Error during final flush: {e}")
|
1271
1171
|
except Exception as e:
|
1272
|
-
|
1172
|
+
judgeval_logger.warning(f"Error during BackgroundSpanService shutdown: {e}")
|
1273
1173
|
finally:
|
1274
1174
|
# Clear the worker threads list (daemon threads will be killed automatically)
|
1275
1175
|
self._worker_threads.clear()
|
@@ -1601,9 +1501,9 @@ class Tracer:
|
|
1601
1501
|
def __init__(
|
1602
1502
|
self,
|
1603
1503
|
api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
|
1604
|
-
project_name: str | None = None,
|
1605
|
-
rules: Optional[List[Rule]] = None, # Added rules parameter
|
1606
1504
|
organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
|
1505
|
+
project_name: str | None = None,
|
1506
|
+
deep_tracing: bool = False, # Deep tracing is disabled by default
|
1607
1507
|
enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
|
1608
1508
|
== "true",
|
1609
1509
|
enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
|
@@ -1614,73 +1514,77 @@ class Tracer:
|
|
1614
1514
|
s3_aws_access_key_id: Optional[str] = None,
|
1615
1515
|
s3_aws_secret_access_key: Optional[str] = None,
|
1616
1516
|
s3_region_name: Optional[str] = None,
|
1617
|
-
offline_mode: bool = False,
|
1618
|
-
deep_tracing: bool = False, # Deep tracing is disabled by default
|
1619
1517
|
trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
|
1620
1518
|
# Background span service configuration
|
1621
|
-
enable_background_spans: bool = True, # Enable background span service by default
|
1622
1519
|
span_batch_size: int = 50, # Number of spans to batch before sending
|
1623
1520
|
span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
|
1624
1521
|
span_num_workers: int = 10, # Number of worker threads for span processing
|
1625
1522
|
):
|
1626
|
-
if not api_key:
|
1627
|
-
raise ValueError("Tracer must be configured with a Judgment API key")
|
1628
|
-
|
1629
1523
|
try:
|
1630
|
-
|
1631
|
-
|
1632
|
-
|
1633
|
-
|
1634
|
-
result = True
|
1635
|
-
|
1636
|
-
if not result:
|
1637
|
-
raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
|
1638
|
-
|
1639
|
-
if not organization_id:
|
1640
|
-
raise ValueError("Tracer must be configured with an Organization ID")
|
1641
|
-
if use_s3 and not s3_bucket_name:
|
1642
|
-
raise ValueError("S3 bucket name must be provided when use_s3 is True")
|
1643
|
-
|
1644
|
-
self.api_key: str = api_key
|
1645
|
-
self.project_name: str = project_name or str(uuid.uuid4())
|
1646
|
-
self.organization_id: str = organization_id
|
1647
|
-
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
1648
|
-
self.traces: List[Trace] = []
|
1649
|
-
self.enable_monitoring: bool = enable_monitoring
|
1650
|
-
self.enable_evaluations: bool = enable_evaluations
|
1651
|
-
self.class_identifiers: Dict[
|
1652
|
-
str, str
|
1653
|
-
] = {} # Dictionary to store class identifiers
|
1654
|
-
self.span_id_to_previous_span_id: Dict[str, str | None] = {}
|
1655
|
-
self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
|
1656
|
-
self.current_span_id: Optional[str] = None
|
1657
|
-
self.current_trace: Optional[TraceClient] = None
|
1658
|
-
self.trace_across_async_contexts: bool = trace_across_async_contexts
|
1659
|
-
Tracer.trace_across_async_contexts = trace_across_async_contexts
|
1660
|
-
|
1661
|
-
# Initialize S3 storage if enabled
|
1662
|
-
self.use_s3 = use_s3
|
1663
|
-
if use_s3:
|
1664
|
-
from judgeval.common.s3_storage import S3Storage
|
1524
|
+
if not api_key:
|
1525
|
+
raise ValueError(
|
1526
|
+
"api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
|
1527
|
+
)
|
1665
1528
|
|
1666
|
-
|
1667
|
-
|
1668
|
-
|
1669
|
-
aws_access_key_id=s3_aws_access_key_id,
|
1670
|
-
aws_secret_access_key=s3_aws_secret_access_key,
|
1671
|
-
region_name=s3_region_name,
|
1529
|
+
if not organization_id:
|
1530
|
+
raise ValueError(
|
1531
|
+
"organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
|
1672
1532
|
)
|
1533
|
+
|
1534
|
+
try:
|
1535
|
+
result, response = validate_api_key(api_key)
|
1673
1536
|
except Exception as e:
|
1674
|
-
|
1675
|
-
|
1537
|
+
judgeval_logger.error(
|
1538
|
+
f"Issue with verifying API key, disabling monitoring: {e}"
|
1539
|
+
)
|
1540
|
+
enable_monitoring = False
|
1541
|
+
result = True
|
1542
|
+
|
1543
|
+
if not result:
|
1544
|
+
raise ValueError(f"Issue with passed in Judgment API key: {response}")
|
1545
|
+
|
1546
|
+
if use_s3 and not s3_bucket_name:
|
1547
|
+
raise ValueError("S3 bucket name must be provided when use_s3 is True")
|
1548
|
+
|
1549
|
+
self.api_key: str = api_key
|
1550
|
+
self.project_name: str = project_name or "default_project"
|
1551
|
+
self.organization_id: str = organization_id
|
1552
|
+
self.traces: List[Trace] = []
|
1553
|
+
self.enable_monitoring: bool = enable_monitoring
|
1554
|
+
self.enable_evaluations: bool = enable_evaluations
|
1555
|
+
self.class_identifiers: Dict[
|
1556
|
+
str, str
|
1557
|
+
] = {} # Dictionary to store class identifiers
|
1558
|
+
self.span_id_to_previous_span_id: Dict[str, str | None] = {}
|
1559
|
+
self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
|
1560
|
+
self.current_span_id: Optional[str] = None
|
1561
|
+
self.current_trace: Optional[TraceClient] = None
|
1562
|
+
self.trace_across_async_contexts: bool = trace_across_async_contexts
|
1563
|
+
Tracer.trace_across_async_contexts = trace_across_async_contexts
|
1564
|
+
|
1565
|
+
# Initialize S3 storage if enabled
|
1566
|
+
self.use_s3 = use_s3
|
1567
|
+
if use_s3:
|
1568
|
+
from judgeval.common.s3_storage import S3Storage
|
1676
1569
|
|
1677
|
-
|
1678
|
-
|
1570
|
+
try:
|
1571
|
+
self.s3_storage = S3Storage(
|
1572
|
+
bucket_name=s3_bucket_name,
|
1573
|
+
aws_access_key_id=s3_aws_access_key_id,
|
1574
|
+
aws_secret_access_key=s3_aws_secret_access_key,
|
1575
|
+
region_name=s3_region_name,
|
1576
|
+
)
|
1577
|
+
except Exception as e:
|
1578
|
+
judgeval_logger.error(
|
1579
|
+
f"Issue with initializing S3 storage, disabling S3: {e}"
|
1580
|
+
)
|
1581
|
+
self.use_s3 = False
|
1582
|
+
|
1583
|
+
self.offline_mode = False # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
|
1584
|
+
self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
|
1679
1585
|
|
1680
|
-
|
1681
|
-
|
1682
|
-
self.background_span_service: Optional[BackgroundSpanService] = None
|
1683
|
-
if enable_background_spans and not offline_mode:
|
1586
|
+
# Initialize background span service
|
1587
|
+
self.background_span_service: Optional[BackgroundSpanService] = None
|
1684
1588
|
self.background_span_service = BackgroundSpanService(
|
1685
1589
|
judgment_api_key=api_key,
|
1686
1590
|
organization_id=organization_id,
|
@@ -1688,6 +1592,12 @@ class Tracer:
|
|
1688
1592
|
flush_interval=span_flush_interval,
|
1689
1593
|
num_workers=span_num_workers,
|
1690
1594
|
)
|
1595
|
+
except Exception as e:
|
1596
|
+
judgeval_logger.error(
|
1597
|
+
f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
|
1598
|
+
)
|
1599
|
+
self.enable_monitoring = False
|
1600
|
+
self.enable_evaluations = False
|
1691
1601
|
|
1692
1602
|
def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
|
1693
1603
|
self.span_id_to_previous_span_id[span_id] = self.current_span_id
|
@@ -1777,11 +1687,7 @@ class Tracer:
|
|
1777
1687
|
|
1778
1688
|
@contextmanager
|
1779
1689
|
def trace(
|
1780
|
-
self,
|
1781
|
-
name: str,
|
1782
|
-
project_name: str | None = None,
|
1783
|
-
overwrite: bool = False,
|
1784
|
-
rules: Optional[List[Rule]] = None, # Added rules parameter
|
1690
|
+
self, name: str, project_name: str | None = None
|
1785
1691
|
) -> Generator[TraceClient, None, None]:
|
1786
1692
|
"""Start a new trace context using a context manager"""
|
1787
1693
|
trace_id = str(uuid.uuid4())
|
@@ -1801,8 +1707,6 @@ class Tracer:
|
|
1801
1707
|
trace_id,
|
1802
1708
|
name,
|
1803
1709
|
project_name=project,
|
1804
|
-
overwrite=overwrite,
|
1805
|
-
rules=self.rules, # Pass combined rules to the trace client
|
1806
1710
|
enable_monitoring=self.enable_monitoring,
|
1807
1711
|
enable_evaluations=self.enable_evaluations,
|
1808
1712
|
parent_trace_id=parent_trace_id,
|
@@ -1939,9 +1843,6 @@ class Tracer:
|
|
1939
1843
|
*,
|
1940
1844
|
name=None,
|
1941
1845
|
span_type: SpanType = "span",
|
1942
|
-
project_name: str | None = None,
|
1943
|
-
overwrite: bool = False,
|
1944
|
-
deep_tracing: bool | None = None,
|
1945
1846
|
):
|
1946
1847
|
"""
|
1947
1848
|
Decorator to trace function execution with detailed entry/exit information.
|
@@ -1949,11 +1850,7 @@ class Tracer:
|
|
1949
1850
|
Args:
|
1950
1851
|
func: The function to decorate
|
1951
1852
|
name: Optional custom name for the span (defaults to function name)
|
1952
|
-
span_type: Type of span (default "span")
|
1953
|
-
project_name: Optional project name override
|
1954
|
-
overwrite: Whether to overwrite existing traces
|
1955
|
-
deep_tracing: Whether to enable deep tracing for this function and all nested calls.
|
1956
|
-
If None, uses the tracer's default setting.
|
1853
|
+
span_type: Type of span (default "span").
|
1957
1854
|
"""
|
1958
1855
|
# If monitoring is disabled, return the function as is
|
1959
1856
|
try:
|
@@ -1965,9 +1862,6 @@ class Tracer:
|
|
1965
1862
|
f,
|
1966
1863
|
name=name,
|
1967
1864
|
span_type=span_type,
|
1968
|
-
project_name=project_name,
|
1969
|
-
overwrite=overwrite,
|
1970
|
-
deep_tracing=deep_tracing,
|
1971
1865
|
)
|
1972
1866
|
|
1973
1867
|
# Use provided name or fall back to function name
|
@@ -1977,10 +1871,6 @@ class Tracer:
|
|
1977
1871
|
func._judgment_span_name = original_span_name
|
1978
1872
|
func._judgment_span_type = span_type
|
1979
1873
|
|
1980
|
-
# Use the provided deep_tracing value or fall back to the tracer's default
|
1981
|
-
use_deep_tracing = (
|
1982
|
-
deep_tracing if deep_tracing is not None else self.deep_tracing
|
1983
|
-
)
|
1984
1874
|
except Exception:
|
1985
1875
|
return func
|
1986
1876
|
|
@@ -2005,9 +1895,7 @@ class Tracer:
|
|
2005
1895
|
# If there's no current trace, create a root trace
|
2006
1896
|
if not current_trace:
|
2007
1897
|
trace_id = str(uuid.uuid4())
|
2008
|
-
project =
|
2009
|
-
project_name if project_name is not None else self.project_name
|
2010
|
-
)
|
1898
|
+
project = self.project_name
|
2011
1899
|
|
2012
1900
|
# Create a new trace client to serve as the root
|
2013
1901
|
current_trace = TraceClient(
|
@@ -2015,14 +1903,10 @@ class Tracer:
|
|
2015
1903
|
trace_id,
|
2016
1904
|
span_name, # MODIFIED: Use span_name directly
|
2017
1905
|
project_name=project,
|
2018
|
-
overwrite=overwrite,
|
2019
|
-
rules=self.rules,
|
2020
1906
|
enable_monitoring=self.enable_monitoring,
|
2021
1907
|
enable_evaluations=self.enable_evaluations,
|
2022
1908
|
)
|
2023
1909
|
|
2024
|
-
# Save empty trace and set trace context
|
2025
|
-
# current_trace.save(empty_save=True, overwrite=overwrite)
|
2026
1910
|
trace_token = self.set_current_trace(current_trace)
|
2027
1911
|
|
2028
1912
|
try:
|
@@ -2043,7 +1927,7 @@ class Tracer:
|
|
2043
1927
|
)
|
2044
1928
|
|
2045
1929
|
try:
|
2046
|
-
if
|
1930
|
+
if self.deep_tracing:
|
2047
1931
|
with _DeepTracer(self):
|
2048
1932
|
result = await func(*args, **kwargs)
|
2049
1933
|
else:
|
@@ -2068,22 +1952,22 @@ class Tracer:
|
|
2068
1952
|
complete_trace_data = {
|
2069
1953
|
"trace_id": current_trace.trace_id,
|
2070
1954
|
"name": current_trace.name,
|
2071
|
-
"created_at": datetime.
|
2072
|
-
current_trace.start_time
|
1955
|
+
"created_at": datetime.fromtimestamp(
|
1956
|
+
current_trace.start_time or time.time(),
|
1957
|
+
timezone.utc,
|
2073
1958
|
).isoformat(),
|
2074
1959
|
"duration": current_trace.get_duration(),
|
2075
1960
|
"trace_spans": [
|
2076
1961
|
span.model_dump()
|
2077
1962
|
for span in current_trace.trace_spans
|
2078
1963
|
],
|
2079
|
-
"overwrite": overwrite,
|
2080
1964
|
"offline_mode": self.offline_mode,
|
2081
1965
|
"parent_trace_id": current_trace.parent_trace_id,
|
2082
1966
|
"parent_name": current_trace.parent_name,
|
2083
1967
|
}
|
2084
1968
|
# Save the completed trace
|
2085
1969
|
trace_id, server_response = current_trace.save(
|
2086
|
-
|
1970
|
+
final_save=True
|
2087
1971
|
)
|
2088
1972
|
|
2089
1973
|
# Store the complete trace data instead of just server response
|
@@ -2096,7 +1980,7 @@ class Tracer:
|
|
2096
1980
|
# Reset trace context (span context resets automatically)
|
2097
1981
|
self.reset_current_trace(trace_token)
|
2098
1982
|
except Exception as e:
|
2099
|
-
|
1983
|
+
judgeval_logger.warning(f"Issue with async_wrapper: {e}")
|
2100
1984
|
return
|
2101
1985
|
else:
|
2102
1986
|
with current_trace.span(span_name, span_type=span_type) as span:
|
@@ -2111,7 +1995,7 @@ class Tracer:
|
|
2111
1995
|
)
|
2112
1996
|
|
2113
1997
|
try:
|
2114
|
-
if
|
1998
|
+
if self.deep_tracing:
|
2115
1999
|
with _DeepTracer(self):
|
2116
2000
|
result = await func(*args, **kwargs)
|
2117
2001
|
else:
|
@@ -2148,9 +2032,7 @@ class Tracer:
|
|
2148
2032
|
# If there's no current trace, create a root trace
|
2149
2033
|
if not current_trace:
|
2150
2034
|
trace_id = str(uuid.uuid4())
|
2151
|
-
project =
|
2152
|
-
project_name if project_name is not None else self.project_name
|
2153
|
-
)
|
2035
|
+
project = self.project_name
|
2154
2036
|
|
2155
2037
|
# Create a new trace client to serve as the root
|
2156
2038
|
current_trace = TraceClient(
|
@@ -2158,14 +2040,10 @@ class Tracer:
|
|
2158
2040
|
trace_id,
|
2159
2041
|
span_name, # MODIFIED: Use span_name directly
|
2160
2042
|
project_name=project,
|
2161
|
-
overwrite=overwrite,
|
2162
|
-
rules=self.rules,
|
2163
2043
|
enable_monitoring=self.enable_monitoring,
|
2164
2044
|
enable_evaluations=self.enable_evaluations,
|
2165
2045
|
)
|
2166
2046
|
|
2167
|
-
# Save empty trace and set trace context
|
2168
|
-
# current_trace.save(empty_save=True, overwrite=overwrite)
|
2169
2047
|
trace_token = self.set_current_trace(current_trace)
|
2170
2048
|
|
2171
2049
|
try:
|
@@ -2185,7 +2063,7 @@ class Tracer:
|
|
2185
2063
|
)
|
2186
2064
|
|
2187
2065
|
try:
|
2188
|
-
if
|
2066
|
+
if self.deep_tracing:
|
2189
2067
|
with _DeepTracer(self):
|
2190
2068
|
result = func(*args, **kwargs)
|
2191
2069
|
else:
|
@@ -2209,22 +2087,22 @@ class Tracer:
|
|
2209
2087
|
try:
|
2210
2088
|
# Save the completed trace
|
2211
2089
|
trace_id, server_response = current_trace.save(
|
2212
|
-
|
2090
|
+
final_save=True
|
2213
2091
|
)
|
2214
2092
|
|
2215
2093
|
# Store the complete trace data instead of just server response
|
2216
2094
|
complete_trace_data = {
|
2217
2095
|
"trace_id": current_trace.trace_id,
|
2218
2096
|
"name": current_trace.name,
|
2219
|
-
"created_at": datetime.
|
2220
|
-
current_trace.start_time
|
2097
|
+
"created_at": datetime.fromtimestamp(
|
2098
|
+
current_trace.start_time or time.time(),
|
2099
|
+
timezone.utc,
|
2221
2100
|
).isoformat(),
|
2222
2101
|
"duration": current_trace.get_duration(),
|
2223
2102
|
"trace_spans": [
|
2224
2103
|
span.model_dump()
|
2225
2104
|
for span in current_trace.trace_spans
|
2226
2105
|
],
|
2227
|
-
"overwrite": overwrite,
|
2228
2106
|
"offline_mode": self.offline_mode,
|
2229
2107
|
"parent_trace_id": current_trace.parent_trace_id,
|
2230
2108
|
"parent_name": current_trace.parent_name,
|
@@ -2233,7 +2111,7 @@ class Tracer:
|
|
2233
2111
|
# Reset trace context (span context resets automatically)
|
2234
2112
|
self.reset_current_trace(trace_token)
|
2235
2113
|
except Exception as e:
|
2236
|
-
|
2114
|
+
judgeval_logger.warning(f"Issue with save: {e}")
|
2237
2115
|
return
|
2238
2116
|
else:
|
2239
2117
|
with current_trace.span(span_name, span_type=span_type) as span:
|
@@ -2248,7 +2126,7 @@ class Tracer:
|
|
2248
2126
|
)
|
2249
2127
|
|
2250
2128
|
try:
|
2251
|
-
if
|
2129
|
+
if self.deep_tracing:
|
2252
2130
|
with _DeepTracer(self):
|
2253
2131
|
result = func(*args, **kwargs)
|
2254
2132
|
else:
|
@@ -2309,8 +2187,8 @@ class Tracer:
|
|
2309
2187
|
if hasattr(method, "_judgment_span_name"):
|
2310
2188
|
skipped.append(name)
|
2311
2189
|
if warn_on_double_decoration:
|
2312
|
-
|
2313
|
-
f"
|
2190
|
+
judgeval_logger.info(
|
2191
|
+
f"{cls.__name__}.{name} already decorated, skipping"
|
2314
2192
|
)
|
2315
2193
|
continue
|
2316
2194
|
|
@@ -2320,7 +2198,9 @@ class Tracer:
|
|
2320
2198
|
decorated.append(name)
|
2321
2199
|
except Exception as e:
|
2322
2200
|
if warn_on_double_decoration:
|
2323
|
-
|
2201
|
+
judgeval_logger.warning(
|
2202
|
+
f"Failed to decorate {cls.__name__}.{name}: {e}"
|
2203
|
+
)
|
2324
2204
|
|
2325
2205
|
return cls
|
2326
2206
|
|
@@ -2347,11 +2227,11 @@ class Tracer:
|
|
2347
2227
|
)
|
2348
2228
|
current_trace.async_evaluate(*args, **kwargs)
|
2349
2229
|
else:
|
2350
|
-
|
2230
|
+
judgeval_logger.warning(
|
2351
2231
|
"No trace found (context var or fallback), skipping evaluation"
|
2352
2232
|
) # Modified warning
|
2353
2233
|
except Exception as e:
|
2354
|
-
|
2234
|
+
judgeval_logger.warning(f"Issue with async_evaluate: {e}")
|
2355
2235
|
|
2356
2236
|
def update_metadata(self, metadata: dict):
|
2357
2237
|
"""
|
@@ -2364,7 +2244,7 @@ class Tracer:
|
|
2364
2244
|
if current_trace:
|
2365
2245
|
current_trace.update_metadata(metadata)
|
2366
2246
|
else:
|
2367
|
-
|
2247
|
+
judgeval_logger.warning("No current trace found, cannot set metadata")
|
2368
2248
|
|
2369
2249
|
def set_customer_id(self, customer_id: str):
|
2370
2250
|
"""
|
@@ -2377,7 +2257,7 @@ class Tracer:
|
|
2377
2257
|
if current_trace:
|
2378
2258
|
current_trace.set_customer_id(customer_id)
|
2379
2259
|
else:
|
2380
|
-
|
2260
|
+
judgeval_logger.warning("No current trace found, cannot set customer ID")
|
2381
2261
|
|
2382
2262
|
def set_tags(self, tags: List[Union[str, set, tuple]]):
|
2383
2263
|
"""
|
@@ -2390,7 +2270,7 @@ class Tracer:
|
|
2390
2270
|
if current_trace:
|
2391
2271
|
current_trace.set_tags(tags)
|
2392
2272
|
else:
|
2393
|
-
|
2273
|
+
judgeval_logger.warning("No current trace found, cannot set tags")
|
2394
2274
|
|
2395
2275
|
def get_background_span_service(self) -> Optional[BackgroundSpanService]:
|
2396
2276
|
"""Get the background span service instance."""
|
@@ -2447,7 +2327,7 @@ def wrap(
|
|
2447
2327
|
# Warn about token counting limitations with streaming
|
2448
2328
|
if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
|
2449
2329
|
if not kwargs.get("stream_options", {}).get("include_usage"):
|
2450
|
-
|
2330
|
+
judgeval_logger.warning(
|
2451
2331
|
"OpenAI streaming calls don't include token counts by default. "
|
2452
2332
|
"To enable token counting with streams, set stream_options={'include_usage': True} "
|
2453
2333
|
"in your API call arguments.",
|
@@ -2746,19 +2626,25 @@ def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
|
|
2746
2626
|
prompt_tokens = 0
|
2747
2627
|
completion_tokens = 0
|
2748
2628
|
model_name = None
|
2749
|
-
|
2629
|
+
cache_read_input_tokens = 0
|
2630
|
+
cache_creation_input_tokens = 0
|
2631
|
+
|
2632
|
+
if isinstance(client, (OpenAI, AsyncOpenAI)):
|
2750
2633
|
model_name = response.model
|
2751
2634
|
prompt_tokens = response.usage.input_tokens
|
2752
2635
|
completion_tokens = response.usage.output_tokens
|
2753
|
-
|
2636
|
+
cache_read_input_tokens = response.usage.input_tokens_details.cached_tokens
|
2637
|
+
message_content = "".join(seg.text for seg in response.output[0].content)
|
2754
2638
|
else:
|
2755
|
-
|
2639
|
+
judgeval_logger.warning(f"Unsupported client type: {type(client)}")
|
2756
2640
|
return None, None
|
2757
2641
|
|
2758
2642
|
prompt_cost, completion_cost = cost_per_token(
|
2759
2643
|
model=model_name,
|
2760
2644
|
prompt_tokens=prompt_tokens,
|
2761
2645
|
completion_tokens=completion_tokens,
|
2646
|
+
cache_read_input_tokens=cache_read_input_tokens,
|
2647
|
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
2762
2648
|
)
|
2763
2649
|
total_cost_usd = (
|
2764
2650
|
(prompt_cost + completion_cost) if prompt_cost and completion_cost else None
|
@@ -2767,6 +2653,8 @@ def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
|
|
2767
2653
|
prompt_tokens=prompt_tokens,
|
2768
2654
|
completion_tokens=completion_tokens,
|
2769
2655
|
total_tokens=prompt_tokens + completion_tokens,
|
2656
|
+
cache_read_input_tokens=cache_read_input_tokens,
|
2657
|
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
2770
2658
|
prompt_tokens_cost_usd=prompt_cost,
|
2771
2659
|
completion_tokens_cost_usd=completion_cost,
|
2772
2660
|
total_cost_usd=total_cost_usd,
|
@@ -2790,13 +2678,28 @@ def _format_output_data(
|
|
2790
2678
|
"""
|
2791
2679
|
prompt_tokens = 0
|
2792
2680
|
completion_tokens = 0
|
2681
|
+
cache_read_input_tokens = 0
|
2682
|
+
cache_creation_input_tokens = 0
|
2793
2683
|
model_name = None
|
2794
2684
|
message_content = None
|
2795
2685
|
|
2796
|
-
if isinstance(client, (OpenAI,
|
2686
|
+
if isinstance(client, (OpenAI, AsyncOpenAI)):
|
2797
2687
|
model_name = response.model
|
2798
2688
|
prompt_tokens = response.usage.prompt_tokens
|
2799
2689
|
completion_tokens = response.usage.completion_tokens
|
2690
|
+
cache_read_input_tokens = response.usage.prompt_tokens_details.cached_tokens
|
2691
|
+
if (
|
2692
|
+
hasattr(response.choices[0].message, "parsed")
|
2693
|
+
and response.choices[0].message.parsed
|
2694
|
+
):
|
2695
|
+
message_content = response.choices[0].message.parsed
|
2696
|
+
else:
|
2697
|
+
message_content = response.choices[0].message.content
|
2698
|
+
|
2699
|
+
elif isinstance(client, (Together, AsyncTogether)):
|
2700
|
+
model_name = "together_ai/" + response.model
|
2701
|
+
prompt_tokens = response.usage.prompt_tokens
|
2702
|
+
completion_tokens = response.usage.completion_tokens
|
2800
2703
|
if (
|
2801
2704
|
hasattr(response.choices[0].message, "parsed")
|
2802
2705
|
and response.choices[0].message.parsed
|
@@ -2813,15 +2716,19 @@ def _format_output_data(
|
|
2813
2716
|
model_name = response.model
|
2814
2717
|
prompt_tokens = response.usage.input_tokens
|
2815
2718
|
completion_tokens = response.usage.output_tokens
|
2719
|
+
cache_read_input_tokens = response.usage.cache_read_input_tokens
|
2720
|
+
cache_creation_input_tokens = response.usage.cache_creation_input_tokens
|
2816
2721
|
message_content = response.content[0].text
|
2817
2722
|
else:
|
2818
|
-
|
2723
|
+
judgeval_logger.warning(f"Unsupported client type: {type(client)}")
|
2819
2724
|
return None, None
|
2820
2725
|
|
2821
2726
|
prompt_cost, completion_cost = cost_per_token(
|
2822
2727
|
model=model_name,
|
2823
2728
|
prompt_tokens=prompt_tokens,
|
2824
2729
|
completion_tokens=completion_tokens,
|
2730
|
+
cache_read_input_tokens=cache_read_input_tokens,
|
2731
|
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
2825
2732
|
)
|
2826
2733
|
total_cost_usd = (
|
2827
2734
|
(prompt_cost + completion_cost) if prompt_cost and completion_cost else None
|
@@ -2830,6 +2737,8 @@ def _format_output_data(
|
|
2830
2737
|
prompt_tokens=prompt_tokens,
|
2831
2738
|
completion_tokens=completion_tokens,
|
2832
2739
|
total_tokens=prompt_tokens + completion_tokens,
|
2740
|
+
cache_read_input_tokens=cache_read_input_tokens,
|
2741
|
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
2833
2742
|
prompt_tokens_cost_usd=prompt_cost,
|
2834
2743
|
completion_tokens_cost_usd=completion_cost,
|
2835
2744
|
total_cost_usd=total_cost_usd,
|
@@ -2969,8 +2878,11 @@ def _extract_usage_from_final_chunk(
|
|
2969
2878
|
prompt_tokens = chunk.choices[0].usage.prompt_tokens
|
2970
2879
|
completion_tokens = chunk.choices[0].usage.completion_tokens
|
2971
2880
|
|
2881
|
+
if isinstance(client, (Together, AsyncTogether)):
|
2882
|
+
model_name = "together_ai/" + chunk.model
|
2883
|
+
|
2972
2884
|
prompt_cost, completion_cost = cost_per_token(
|
2973
|
-
model=
|
2885
|
+
model=model_name,
|
2974
2886
|
prompt_tokens=prompt_tokens,
|
2975
2887
|
completion_tokens=completion_tokens,
|
2976
2888
|
)
|
@@ -3161,9 +3073,17 @@ async def _async_stream_wrapper(
|
|
3161
3073
|
|
3162
3074
|
def cost_per_token(*args, **kwargs):
|
3163
3075
|
try:
|
3164
|
-
|
3076
|
+
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = (
|
3077
|
+
_original_cost_per_token(*args, **kwargs)
|
3078
|
+
)
|
3079
|
+
if (
|
3080
|
+
prompt_tokens_cost_usd_dollar == 0
|
3081
|
+
and completion_tokens_cost_usd_dollar == 0
|
3082
|
+
):
|
3083
|
+
judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
|
3084
|
+
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
3165
3085
|
except Exception as e:
|
3166
|
-
|
3086
|
+
judgeval_logger.warning(f"Error calculating cost per token: {e}")
|
3167
3087
|
return None, None
|
3168
3088
|
|
3169
3089
|
|