judgeval 0.0.52__py3-none-any.whl → 0.0.53__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. judgeval/common/logger.py +46 -199
  2. judgeval/common/s3_storage.py +2 -6
  3. judgeval/common/tracer.py +182 -262
  4. judgeval/common/utils.py +16 -36
  5. judgeval/constants.py +14 -20
  6. judgeval/data/__init__.py +0 -2
  7. judgeval/data/datasets/dataset.py +6 -10
  8. judgeval/data/datasets/eval_dataset_client.py +25 -27
  9. judgeval/data/example.py +5 -138
  10. judgeval/data/judgment_types.py +214 -0
  11. judgeval/data/result.py +7 -25
  12. judgeval/data/scorer_data.py +28 -40
  13. judgeval/data/scripts/fix_default_factory.py +23 -0
  14. judgeval/data/scripts/openapi_transform.py +123 -0
  15. judgeval/data/tool.py +3 -54
  16. judgeval/data/trace.py +31 -50
  17. judgeval/data/trace_run.py +3 -3
  18. judgeval/evaluation_run.py +16 -23
  19. judgeval/integrations/langgraph.py +11 -12
  20. judgeval/judges/litellm_judge.py +3 -6
  21. judgeval/judges/mixture_of_judges.py +8 -25
  22. judgeval/judges/together_judge.py +3 -6
  23. judgeval/judgment_client.py +22 -24
  24. judgeval/rules.py +7 -19
  25. judgeval/run_evaluation.py +79 -242
  26. judgeval/scorers/__init__.py +4 -20
  27. judgeval/scorers/agent_scorer.py +21 -0
  28. judgeval/scorers/api_scorer.py +28 -38
  29. judgeval/scorers/base_scorer.py +98 -0
  30. judgeval/scorers/example_scorer.py +19 -0
  31. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -20
  32. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +10 -17
  33. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +9 -24
  34. judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +16 -68
  35. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +4 -12
  36. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +4 -4
  37. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +10 -17
  38. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +4 -4
  39. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +4 -4
  40. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +4 -4
  41. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +18 -14
  42. judgeval/scorers/score.py +45 -330
  43. judgeval/scorers/utils.py +6 -88
  44. judgeval/utils/file_utils.py +4 -6
  45. judgeval/version_check.py +3 -2
  46. {judgeval-0.0.52.dist-info → judgeval-0.0.53.dist-info}/METADATA +2 -1
  47. judgeval-0.0.53.dist-info/RECORD +65 -0
  48. judgeval/data/custom_example.py +0 -19
  49. judgeval/scorers/judgeval_scorer.py +0 -177
  50. judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +0 -45
  51. judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +0 -29
  52. judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +0 -29
  53. judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +0 -32
  54. judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +0 -28
  55. judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +0 -38
  56. judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +0 -27
  57. judgeval/scorers/prompt_scorer.py +0 -296
  58. judgeval-0.0.52.dist-info/RECORD +0 -69
  59. {judgeval-0.0.52.dist-info → judgeval-0.0.53.dist-info}/WHEEL +0 -0
  60. {judgeval-0.0.52.dist-info → judgeval-0.0.53.dist-info}/licenses/LICENSE.md +0 -0
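Before the per-file diff, a quick orientation note: the headline API change visible in judgeval/common/tracer.py below is the move from the old scorer classes to the new config/base classes, plus removal of the rules and overwrite plumbing and of the old warnings-based logging. A minimal before/after sketch of the import change, inferred only from the hunks below (not from the package changelog):

# judgeval 0.0.52 (old imports, removed in this diff)
# from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
# from judgeval.rules import Rule

# judgeval 0.0.53 (new imports added in this diff)
from judgeval.scorers import APIScorerConfig, BaseScorer
from judgeval.common.logger import judgeval_logger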
judgeval/common/tracer.py CHANGED
@@ -13,7 +13,6 @@ import threading
  import time
  import traceback
  import uuid
- import warnings
  import contextvars
  import sys
  import json
@@ -23,7 +22,7 @@ from contextlib import (
  AbstractContextManager,
  ) # Import context manager bases
  from dataclasses import dataclass
- from datetime import datetime
+ from datetime import datetime, timezone
  from http import HTTPStatus
  from typing import (
  Any,
@@ -53,7 +52,6 @@ from google import genai
  # Local application/library-specific imports
  from judgeval.constants import (
  JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
- JUDGMENT_TRACES_SAVE_API_URL,
  JUDGMENT_TRACES_UPSERT_API_URL,
  JUDGMENT_TRACES_FETCH_API_URL,
  JUDGMENT_TRACES_DELETE_API_URL,
@@ -62,11 +60,10 @@ from judgeval.constants import (
  JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
  )
  from judgeval.data import Example, Trace, TraceSpan, TraceUsage
- from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
- from judgeval.rules import Rule
+ from judgeval.scorers import APIScorerConfig, BaseScorer
  from judgeval.evaluation_run import EvaluationRun
  from judgeval.common.utils import ExcInfo, validate_api_key
- from judgeval.common.exceptions import JudgmentAPIError
+ from judgeval.common.logger import judgeval_logger

  # Standard library imports needed for the new class
  import concurrent.futures
@@ -157,80 +154,6 @@ class TraceManagerClient:

  return response.json()

- def save_trace(
- self, trace_data: dict, offline_mode: bool = False, final_save: bool = True
- ):
- """
- Saves a trace to the Judgment Supabase and optionally to S3 if configured.
-
- Args:
- trace_data: The trace data to save
- offline_mode: Whether running in offline mode
- final_save: Whether this is the final save (controls S3 saving)
- NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
-
- Returns:
- dict: Server response containing UI URL and other metadata
- """
- # Save to Judgment API
-
- def fallback_encoder(obj):
- """
- Custom JSON encoder fallback.
- Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
- You can choose which one you prefer or try them in sequence.
- """
- try:
- # Option 1: Prefer __repr__ for a more detailed representation
- return repr(obj)
- except Exception:
- # Option 2: Fallback to str() if __repr__ fails or if you prefer str()
- try:
- return str(obj)
- except Exception as e:
- # If both fail, you might return a placeholder or re-raise
- return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
- serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
- response = requests.post(
- JUDGMENT_TRACES_SAVE_API_URL,
- data=serialized_trace_data,
- headers={
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.judgment_api_key}",
- "X-Organization-Id": self.organization_id,
- },
- verify=True,
- )
-
- if response.status_code == HTTPStatus.BAD_REQUEST:
- raise ValueError(
- f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}"
- )
- elif response.status_code != HTTPStatus.OK:
- raise ValueError(f"Failed to save trace data: {response.text}")
-
- # Parse server response
- server_response = response.json()
-
- # If S3 storage is enabled, save to S3 only on final save
- if self.tracer and self.tracer.use_s3 and final_save:
- try:
- s3_key = self.tracer.s3_storage.save_trace(
- trace_data=trace_data,
- trace_id=trace_data["trace_id"],
- project_name=trace_data["project_name"],
- )
- print(f"Trace also saved to S3 at key: {s3_key}")
- except Exception as e:
- warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
- if not offline_mode and "ui_results_url" in server_response:
- pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
- rprint(pretty_str)
-
- return server_response
-
  def upsert_trace(
  self,
  trace_data: dict,
@@ -291,9 +214,9 @@ class TraceManagerClient:
  trace_id=trace_data["trace_id"],
  project_name=trace_data["project_name"],
  )
- print(f"Trace also saved to S3 at key: {s3_key}")
+ judgeval_logger.info(f"Trace also saved to S3 at key: {s3_key}")
  except Exception as e:
- warnings.warn(f"Failed to save trace to S3: {str(e)}")
+ judgeval_logger.warning(f"Failed to save trace to S3: {str(e)}")

  if not offline_mode and show_link and "ui_results_url" in server_response:
  pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
@@ -401,8 +324,6 @@ class TraceClient:
  trace_id: Optional[str] = None,
  name: str = "default",
  project_name: str | None = None,
- overwrite: bool = False,
- rules: Optional[List[Rule]] = None,
  enable_monitoring: bool = True,
  enable_evaluations: bool = True,
  parent_trace_id: Optional[str] = None,
@@ -410,10 +331,8 @@
  ):
  self.name = name
  self.trace_id = trace_id or str(uuid.uuid4())
- self.project_name = project_name or str(uuid.uuid4())
- self.overwrite = overwrite
+ self.project_name = project_name or "default_project"
  self.tracer = tracer
- self.rules = rules or []
  self.enable_monitoring = enable_monitoring
  self.enable_evaluations = enable_evaluations
  self.parent_trace_id = parent_trace_id
@@ -422,6 +341,7 @@ class TraceClient:
  self.tags: List[Union[str, set, tuple]] = [] # Added tags attribute
  self.metadata: Dict[str, Any] = {}
  self.has_notification: Optional[bool] = False # Initialize has_notification
+ self.update_id: int = 1 # Initialize update_id to 1, increments with each save
  self.trace_spans: List[TraceSpan] = []
  self.span_id_to_span: Dict[str, TraceSpan] = {}
  self.evaluation_runs: List[EvaluationRun] = []
@@ -457,15 +377,13 @@ class TraceClient:
  is_first_span = len(self.trace_spans) == 0
  if is_first_span:
  try:
- trace_id, server_response = self.save(
- overwrite=self.overwrite, final_save=False
- )
+ trace_id, server_response = self.save(final_save=False)
  # Set start_time after first successful save
- if self.start_time is None:
- self.start_time = time.time()
  # Link will be shown by upsert_trace method
  except Exception as e:
- warnings.warn(f"Failed to save initial trace for live tracking: {e}")
+ judgeval_logger.warning(
+ f"Failed to save initial trace for live tracking: {e}"
+ )
  start_time = time.time()

  # Generate a unique ID for *this specific span invocation*
@@ -518,7 +436,7 @@ class TraceClient:

  def async_evaluate(
  self,
- scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+ scorers: List[Union[APIScorerConfig, BaseScorer]],
  example: Optional[Example] = None,
  input: Optional[str] = None,
  actual_output: Optional[Union[str, List[str]]] = None,
@@ -539,19 +457,11 @@ class TraceClient:
  try:
  # Load appropriate implementations for all scorers
  if not scorers:
- warnings.warn("No valid scorers available for evaluation")
+ judgeval_logger.warning("No valid scorers available for evaluation")
  return

- # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
- if self.rules and any(
- isinstance(scorer, JudgevalScorer) for scorer in scorers
- ):
- raise ValueError(
- "Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types."
- )
-
  except Exception as e:
- warnings.warn(f"Failed to load scorers: {str(e)}")
+ judgeval_logger.warning(f"Failed to load scorers: {str(e)}")
  return

  # If example is not provided, create one from the individual parameters
@@ -589,15 +499,8 @@ class TraceClient:

  # check_examples([example], scorers)

- # --- Modification: Capture span_id immediately ---
- # span_id_at_eval_call = current_span_var.get()
- # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
- # Prioritize explicitly passed span_id, fallback to context var
  span_id_to_use = span_id if span_id is not None else self.get_current_span()
- # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
- # --- End Modification ---

- # Combine the trace-level rules with any evaluation-specific rules)
  eval_run = EvaluationRun(
  organization_id=self.tracer.organization_id,
  project_name=self.project_name,
@@ -608,7 +511,6 @@ class TraceClient:
  scorers=scorers,
  model=model,
  judgment_api_key=self.tracer.api_key,
- override=self.overwrite,
  trace_span_id=span_id_to_use,
  )

@@ -631,7 +533,6 @@ class TraceClient:

  if current_span_id:
  span = self.span_id_to_span[current_span_id]
- span.evaluation_runs.append(eval_run)
  span.has_evaluation = True # Set the has_evaluation flag
  self.evaluation_runs.append(eval_run)

@@ -654,7 +555,7 @@ class TraceClient:
  if self.background_span_service:
  self.background_span_service.queue_span(span, span_state="input")
  except Exception as e:
- warnings.warn(f"Failed to queue span with input data: {e}")
+ judgeval_logger.warning(f"Failed to queue span with input data: {e}")

  def record_agent_name(self, agent_name: str):
  current_span_id = self.get_current_span()
@@ -779,15 +680,12 @@ class TraceClient:
  return 0.0 # No duration if trace hasn't been saved yet
  return time.time() - self.start_time

- def save(
- self, overwrite: bool = False, final_save: bool = False
- ) -> Tuple[str, dict]:
+ def save(self, final_save: bool = False) -> Tuple[str, dict]:
  """
  Save the current trace to the database with rate limiting checks.
  First checks usage limits, then upserts the trace if allowed.

  Args:
- overwrite: Whether to overwrite existing traces
  final_save: Whether this is the final save (updates usage counters)

  Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
@@ -801,17 +699,19 @@ class TraceClient:
  "trace_id": self.trace_id,
  "name": self.name,
  "project_name": self.project_name,
- "created_at": datetime.utcfromtimestamp(time.time()).isoformat(),
+ "created_at": datetime.fromtimestamp(
+ self.start_time or time.time(), timezone.utc
+ ).isoformat(),
  "duration": total_duration,
  "trace_spans": [span.model_dump() for span in self.trace_spans],
  "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
- "overwrite": overwrite,
  "offline_mode": self.tracer.offline_mode,
  "parent_trace_id": self.parent_trace_id,
  "parent_name": self.parent_name,
  "customer_id": self.customer_id,
  "tags": self.tags,
  "metadata": self.metadata,
+ "update_id": self.update_id,
  }

  # If usage check passes, upsert the trace
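The hunk above swaps the deprecated, naive datetime.utcfromtimestamp() for a timezone-aware datetime.fromtimestamp(..., timezone.utc). A minimal illustration of the difference (plain Python behavior, not judgeval-specific):

from datetime import datetime, timezone
import time

ts = time.time()
# Deprecated since Python 3.12; returns a naive datetime with tzinfo=None
naive = datetime.utcfromtimestamp(ts)
# Replacement used in 0.0.53: an aware datetime pinned to UTC
aware = datetime.fromtimestamp(ts, timezone.utc)
print(naive.tzinfo, aware.isoformat())  # None vs "...+00:00"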
@@ -822,12 +722,15 @@ class TraceClient:
  final_save=final_save, # Pass final_save to control S3 saving
  )

+ if self.start_time is None:
+ self.start_time = time.time()
+
+ self.update_id += 1
+
  # Upload annotations
  # TODO: batch to the log endpoint
  for annotation in self.annotations:
  self.trace_manager_client.save_annotation(annotation)
- if self.start_time is None:
- self.start_time = time.time()
  return self.trace_id, server_response

  def delete(self):
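Taken together with the TraceClient.__init__ hunk earlier, save() now stamps every upsert with a monotonically increasing update_id (starts at 1, is sent in the payload, and is incremented after each successful save), presumably so the backend can order repeated upserts of the same trace. A minimal sketch of that pattern with a hypothetical client class (not the judgeval API):

class VersionedUpserter:
    """Hypothetical illustration of the update_id pattern seen above."""

    def __init__(self):
        self.update_id = 1  # the first save carries version 1

    def save(self, payload: dict) -> dict:
        body = {**payload, "update_id": self.update_id}
        # ... send `body` to an upsert endpoint here ...
        self.update_id += 1  # later saves supersede earlier ones
        return body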
@@ -844,7 +747,6 @@ class TraceClient:
  - customer_id: ID of the customer using this trace
  - tags: List of tags for this trace
  - has_notification: Whether this trace has a notification
- - overwrite: Whether to overwrite existing traces
  - name: Name of the trace
  """
  for k, v in metadata.items():
@@ -872,10 +774,6 @@ class TraceClient:
  f"has_notification must be a boolean, got {type(v)}"
  )
  self.has_notification = v
- elif k == "overwrite":
- if not isinstance(v, bool):
- raise ValueError(f"overwrite must be a boolean, got {type(v)}")
- self.overwrite = v
  elif k == "name":
  self.name = v
  else:
@@ -1056,7 +954,7 @@ class BackgroundSpanService:
  last_flush_time = current_time

  except Exception as e:
- warnings.warn(f"Error in span service worker loop: {e}")
+ judgeval_logger.warning(f"Error in span service worker loop: {e}")
  # On error, still need to mark tasks as done to prevent hanging
  for _ in range(pending_task_count):
  self._span_queue.task_done()
@@ -1100,7 +998,7 @@ class BackgroundSpanService:
  self._send_evaluation_runs_batch(evaluation_runs_to_send)

  except Exception as e:
- warnings.warn(f"Failed to send batch: {e}")
+ judgeval_logger.warning(f"Failed to send batch: {e}")

  def _send_spans_batch(self, spans: List[Dict[str, Any]]):
  """Send a batch of spans to the spans endpoint."""
@@ -1133,14 +1031,14 @@ class BackgroundSpanService:
  )

  if response.status_code != HTTPStatus.OK:
- warnings.warn(
+ judgeval_logger.warning(
  f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
  )

  except RequestException as e:
- warnings.warn(f"Network error sending spans batch: {e}")
+ judgeval_logger.warning(f"Network error sending spans batch: {e}")
  except Exception as e:
- warnings.warn(f"Failed to serialize or send spans batch: {e}")
+ judgeval_logger.warning(f"Failed to serialize or send spans batch: {e}")

  def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
  """Send a batch of evaluation runs with their associated span data to the endpoint."""
@@ -1195,14 +1093,14 @@ class BackgroundSpanService:
  )

  if response.status_code != HTTPStatus.OK:
- warnings.warn(
+ judgeval_logger.warning(
  f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
  )

  except RequestException as e:
- warnings.warn(f"Network error sending evaluation runs batch: {e}")
+ judgeval_logger.warning(f"Network error sending evaluation runs batch: {e}")
  except Exception as e:
- warnings.warn(f"Failed to send evaluation runs batch: {e}")
+ judgeval_logger.warning(f"Failed to send evaluation runs batch: {e}")

  def queue_span(self, span: TraceSpan, span_state: str = "input"):
  """
@@ -1213,6 +1111,8 @@ class BackgroundSpanService:
  span_state: State of the span ("input", "output", "completed")
  """
  if not self._shutdown_event.is_set():
+ # Increment update_id when queueing the span
+ span.increment_update_id()
  span_data = {
  "type": "span",
  "data": {
@@ -1252,7 +1152,7 @@ class BackgroundSpanService:
  # Wait for the queue to be processed
  self._span_queue.join()
  except Exception as e:
- warnings.warn(f"Error during flush: {e}")
+ judgeval_logger.warning(f"Error during flush: {e}")

  def shutdown(self):
  """Shutdown the background service and flush remaining spans."""
@@ -1267,9 +1167,9 @@ class BackgroundSpanService:
  try:
  self.flush()
  except Exception as e:
- warnings.warn(f"Error during final flush: {e}")
+ judgeval_logger.warning(f"Error during final flush: {e}")
  except Exception as e:
- warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
+ judgeval_logger.warning(f"Error during BackgroundSpanService shutdown: {e}")
  finally:
  # Clear the worker threads list (daemon threads will be killed automatically)
  self._worker_threads.clear()
@@ -1601,9 +1501,9 @@ class Tracer:
  def __init__(
  self,
  api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
- project_name: str | None = None,
- rules: Optional[List[Rule]] = None, # Added rules parameter
  organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
+ project_name: str | None = None,
+ deep_tracing: bool = False, # Deep tracing is disabled by default
  enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
  == "true",
  enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
@@ -1614,73 +1514,77 @@ class Tracer:
  s3_aws_access_key_id: Optional[str] = None,
  s3_aws_secret_access_key: Optional[str] = None,
  s3_region_name: Optional[str] = None,
- offline_mode: bool = False,
- deep_tracing: bool = False, # Deep tracing is disabled by default
  trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
  # Background span service configuration
- enable_background_spans: bool = True, # Enable background span service by default
  span_batch_size: int = 50, # Number of spans to batch before sending
  span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
  span_num_workers: int = 10, # Number of worker threads for span processing
  ):
- if not api_key:
- raise ValueError("Tracer must be configured with a Judgment API key")
-
  try:
- result, response = validate_api_key(api_key)
- except Exception as e:
- print(f"Issue with verifying API key, disabling monitoring: {e}")
- enable_monitoring = False
- result = True
-
- if not result:
- raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
-
- if not organization_id:
- raise ValueError("Tracer must be configured with an Organization ID")
- if use_s3 and not s3_bucket_name:
- raise ValueError("S3 bucket name must be provided when use_s3 is True")
-
- self.api_key: str = api_key
- self.project_name: str = project_name or str(uuid.uuid4())
- self.organization_id: str = organization_id
- self.rules: List[Rule] = rules or [] # Store rules at tracer level
- self.traces: List[Trace] = []
- self.enable_monitoring: bool = enable_monitoring
- self.enable_evaluations: bool = enable_evaluations
- self.class_identifiers: Dict[
- str, str
- ] = {} # Dictionary to store class identifiers
- self.span_id_to_previous_span_id: Dict[str, str | None] = {}
- self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
- self.current_span_id: Optional[str] = None
- self.current_trace: Optional[TraceClient] = None
- self.trace_across_async_contexts: bool = trace_across_async_contexts
- Tracer.trace_across_async_contexts = trace_across_async_contexts
-
- # Initialize S3 storage if enabled
- self.use_s3 = use_s3
- if use_s3:
- from judgeval.common.s3_storage import S3Storage
+ if not api_key:
+ raise ValueError(
+ "api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
+ )

- try:
- self.s3_storage = S3Storage(
- bucket_name=s3_bucket_name,
- aws_access_key_id=s3_aws_access_key_id,
- aws_secret_access_key=s3_aws_secret_access_key,
- region_name=s3_region_name,
+ if not organization_id:
+ raise ValueError(
+ "organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
  )
+
+ try:
+ result, response = validate_api_key(api_key)
  except Exception as e:
- print(f"Issue with initializing S3 storage, disabling S3: {e}")
- self.use_s3 = False
+ judgeval_logger.error(
+ f"Issue with verifying API key, disabling monitoring: {e}"
+ )
+ enable_monitoring = False
+ result = True
+
+ if not result:
+ raise ValueError(f"Issue with passed in Judgment API key: {response}")
+
+ if use_s3 and not s3_bucket_name:
+ raise ValueError("S3 bucket name must be provided when use_s3 is True")
+
+ self.api_key: str = api_key
+ self.project_name: str = project_name or "default_project"
+ self.organization_id: str = organization_id
+ self.traces: List[Trace] = []
+ self.enable_monitoring: bool = enable_monitoring
+ self.enable_evaluations: bool = enable_evaluations
+ self.class_identifiers: Dict[
+ str, str
+ ] = {} # Dictionary to store class identifiers
+ self.span_id_to_previous_span_id: Dict[str, str | None] = {}
+ self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
+ self.current_span_id: Optional[str] = None
+ self.current_trace: Optional[TraceClient] = None
+ self.trace_across_async_contexts: bool = trace_across_async_contexts
+ Tracer.trace_across_async_contexts = trace_across_async_contexts
+
+ # Initialize S3 storage if enabled
+ self.use_s3 = use_s3
+ if use_s3:
+ from judgeval.common.s3_storage import S3Storage

- self.offline_mode: bool = offline_mode
- self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
+ try:
+ self.s3_storage = S3Storage(
+ bucket_name=s3_bucket_name,
+ aws_access_key_id=s3_aws_access_key_id,
+ aws_secret_access_key=s3_aws_secret_access_key,
+ region_name=s3_region_name,
+ )
+ except Exception as e:
+ judgeval_logger.error(
+ f"Issue with initializing S3 storage, disabling S3: {e}"
+ )
+ self.use_s3 = False
+
+ self.offline_mode = False # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
+ self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting

- # Initialize background span service
- self.enable_background_spans: bool = enable_background_spans
- self.background_span_service: Optional[BackgroundSpanService] = None
- if enable_background_spans and not offline_mode:
+ # Initialize background span service
+ self.background_span_service: Optional[BackgroundSpanService] = None
  self.background_span_service = BackgroundSpanService(
  judgment_api_key=api_key,
  organization_id=organization_id,
@@ -1688,6 +1592,12 @@ class Tracer:
  flush_interval=span_flush_interval,
  num_workers=span_num_workers,
  )
+ except Exception as e:
+ judgeval_logger.error(
+ f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
+ )
+ self.enable_monitoring = False
+ self.enable_evaluations = False

  def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
  self.span_id_to_previous_span_id[span_id] = self.current_span_id
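Based on the 0.0.53 __init__ signature above, a Tracer is no longer configured with rules, offline_mode, or enable_background_spans, and project_name now falls back to "default_project". A hedged usage sketch (argument names taken from the hunk; the import path assumes the module shown in this diff):

from judgeval.common.tracer import Tracer

tracer = Tracer(
    api_key="...",               # or set JUDGMENT_API_KEY
    organization_id="...",       # or set JUDGMENT_ORG_ID
    project_name="my_project",   # defaults to "default_project" if omitted
    deep_tracing=False,          # disabled by default in 0.0.53
)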
@@ -1777,11 +1687,7 @@ class Tracer:

  @contextmanager
  def trace(
- self,
- name: str,
- project_name: str | None = None,
- overwrite: bool = False,
- rules: Optional[List[Rule]] = None, # Added rules parameter
+ self, name: str, project_name: str | None = None
  ) -> Generator[TraceClient, None, None]:
  """Start a new trace context using a context manager"""
  trace_id = str(uuid.uuid4())
@@ -1801,8 +1707,6 @@ class Tracer:
  trace_id,
  name,
  project_name=project,
- overwrite=overwrite,
- rules=self.rules, # Pass combined rules to the trace client
  enable_monitoring=self.enable_monitoring,
  enable_evaluations=self.enable_evaluations,
  parent_trace_id=parent_trace_id,
@@ -1939,9 +1843,6 @@ class Tracer:
  *,
  name=None,
  span_type: SpanType = "span",
- project_name: str | None = None,
- overwrite: bool = False,
- deep_tracing: bool | None = None,
  ):
  """
  Decorator to trace function execution with detailed entry/exit information.
@@ -1949,11 +1850,7 @@ class Tracer:
  Args:
  func: The function to decorate
  name: Optional custom name for the span (defaults to function name)
- span_type: Type of span (default "span")
- project_name: Optional project name override
- overwrite: Whether to overwrite existing traces
- deep_tracing: Whether to enable deep tracing for this function and all nested calls.
- If None, uses the tracer's default setting.
+ span_type: Type of span (default "span").
  """
  # If monitoring is disabled, return the function as is
  try:
@@ -1965,9 +1862,6 @@ class Tracer:
  f,
  name=name,
  span_type=span_type,
- project_name=project_name,
- overwrite=overwrite,
- deep_tracing=deep_tracing,
  )

  # Use provided name or fall back to function name
@@ -1977,10 +1871,6 @@ class Tracer:
  func._judgment_span_name = original_span_name
  func._judgment_span_type = span_type

- # Use the provided deep_tracing value or fall back to the tracer's default
- use_deep_tracing = (
- deep_tracing if deep_tracing is not None else self.deep_tracing
- )
  except Exception:
  return func

@@ -2005,9 +1895,7 @@ class Tracer:
  # If there's no current trace, create a root trace
  if not current_trace:
  trace_id = str(uuid.uuid4())
- project = (
- project_name if project_name is not None else self.project_name
- )
+ project = self.project_name

  # Create a new trace client to serve as the root
  current_trace = TraceClient(
@@ -2015,14 +1903,10 @@ class Tracer:
  trace_id,
  span_name, # MODIFIED: Use span_name directly
  project_name=project,
- overwrite=overwrite,
- rules=self.rules,
  enable_monitoring=self.enable_monitoring,
  enable_evaluations=self.enable_evaluations,
  )

- # Save empty trace and set trace context
- # current_trace.save(empty_save=True, overwrite=overwrite)
  trace_token = self.set_current_trace(current_trace)

  try:
@@ -2043,7 +1927,7 @@ class Tracer:
  )

  try:
- if use_deep_tracing:
+ if self.deep_tracing:
  with _DeepTracer(self):
  result = await func(*args, **kwargs)
  else:
@@ -2068,22 +1952,22 @@ class Tracer:
  complete_trace_data = {
  "trace_id": current_trace.trace_id,
  "name": current_trace.name,
- "created_at": datetime.utcfromtimestamp(
- current_trace.start_time
+ "created_at": datetime.fromtimestamp(
+ current_trace.start_time or time.time(),
+ timezone.utc,
  ).isoformat(),
  "duration": current_trace.get_duration(),
  "trace_spans": [
  span.model_dump()
  for span in current_trace.trace_spans
  ],
- "overwrite": overwrite,
  "offline_mode": self.offline_mode,
  "parent_trace_id": current_trace.parent_trace_id,
  "parent_name": current_trace.parent_name,
  }
  # Save the completed trace
  trace_id, server_response = current_trace.save(
- overwrite=overwrite, final_save=True
+ final_save=True
  )

  # Store the complete trace data instead of just server response
@@ -2096,7 +1980,7 @@ class Tracer:
  # Reset trace context (span context resets automatically)
  self.reset_current_trace(trace_token)
  except Exception as e:
- warnings.warn(f"Issue with async_wrapper: {e}")
+ judgeval_logger.warning(f"Issue with async_wrapper: {e}")
  return
  else:
  with current_trace.span(span_name, span_type=span_type) as span:
@@ -2111,7 +1995,7 @@ class Tracer:
  )

  try:
- if use_deep_tracing:
+ if self.deep_tracing:
  with _DeepTracer(self):
  result = await func(*args, **kwargs)
  else:
@@ -2148,9 +2032,7 @@ class Tracer:
  # If there's no current trace, create a root trace
  if not current_trace:
  trace_id = str(uuid.uuid4())
- project = (
- project_name if project_name is not None else self.project_name
- )
+ project = self.project_name

  # Create a new trace client to serve as the root
  current_trace = TraceClient(
@@ -2158,14 +2040,10 @@ class Tracer:
  trace_id,
  span_name, # MODIFIED: Use span_name directly
  project_name=project,
- overwrite=overwrite,
- rules=self.rules,
  enable_monitoring=self.enable_monitoring,
  enable_evaluations=self.enable_evaluations,
  )

- # Save empty trace and set trace context
- # current_trace.save(empty_save=True, overwrite=overwrite)
  trace_token = self.set_current_trace(current_trace)

  try:
@@ -2185,7 +2063,7 @@ class Tracer:
  )

  try:
- if use_deep_tracing:
+ if self.deep_tracing:
  with _DeepTracer(self):
  result = func(*args, **kwargs)
  else:
@@ -2209,22 +2087,22 @@ class Tracer:
  try:
  # Save the completed trace
  trace_id, server_response = current_trace.save(
- overwrite=overwrite, final_save=True
+ final_save=True
  )

  # Store the complete trace data instead of just server response
  complete_trace_data = {
  "trace_id": current_trace.trace_id,
  "name": current_trace.name,
- "created_at": datetime.utcfromtimestamp(
- current_trace.start_time
+ "created_at": datetime.fromtimestamp(
+ current_trace.start_time or time.time(),
+ timezone.utc,
  ).isoformat(),
  "duration": current_trace.get_duration(),
  "trace_spans": [
  span.model_dump()
  for span in current_trace.trace_spans
  ],
- "overwrite": overwrite,
  "offline_mode": self.offline_mode,
  "parent_trace_id": current_trace.parent_trace_id,
  "parent_name": current_trace.parent_name,
@@ -2233,7 +2111,7 @@ class Tracer:
  # Reset trace context (span context resets automatically)
  self.reset_current_trace(trace_token)
  except Exception as e:
- warnings.warn(f"Issue with save: {e}")
+ judgeval_logger.warning(f"Issue with save: {e}")
  return
  else:
  with current_trace.span(span_name, span_type=span_type) as span:
@@ -2248,7 +2126,7 @@ class Tracer:
  )

  try:
- if use_deep_tracing:
+ if self.deep_tracing:
  with _DeepTracer(self):
  result = func(*args, **kwargs)
  else:
@@ -2309,8 +2187,8 @@ class Tracer:
  if hasattr(method, "_judgment_span_name"):
  skipped.append(name)
  if warn_on_double_decoration:
- print(
- f"Warning: {cls.__name__}.{name} already decorated, skipping"
+ judgeval_logger.info(
+ f"{cls.__name__}.{name} already decorated, skipping"
  )
  continue

@@ -2320,7 +2198,9 @@ class Tracer:
  decorated.append(name)
  except Exception as e:
  if warn_on_double_decoration:
- print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")
+ judgeval_logger.warning(
+ f"Failed to decorate {cls.__name__}.{name}: {e}"
+ )

  return cls

@@ -2347,11 +2227,11 @@ class Tracer:
  )
  current_trace.async_evaluate(*args, **kwargs)
  else:
- warnings.warn(
+ judgeval_logger.warning(
  "No trace found (context var or fallback), skipping evaluation"
  ) # Modified warning
  except Exception as e:
- warnings.warn(f"Issue with async_evaluate: {e}")
+ judgeval_logger.warning(f"Issue with async_evaluate: {e}")

  def update_metadata(self, metadata: dict):
  """
@@ -2364,7 +2244,7 @@ class Tracer:
  if current_trace:
  current_trace.update_metadata(metadata)
  else:
- warnings.warn("No current trace found, cannot set metadata")
+ judgeval_logger.warning("No current trace found, cannot set metadata")

  def set_customer_id(self, customer_id: str):
  """
@@ -2377,7 +2257,7 @@ class Tracer:
  if current_trace:
  current_trace.set_customer_id(customer_id)
  else:
- warnings.warn("No current trace found, cannot set customer ID")
+ judgeval_logger.warning("No current trace found, cannot set customer ID")

  def set_tags(self, tags: List[Union[str, set, tuple]]):
  """
@@ -2390,7 +2270,7 @@ class Tracer:
  if current_trace:
  current_trace.set_tags(tags)
  else:
- warnings.warn("No current trace found, cannot set tags")
+ judgeval_logger.warning("No current trace found, cannot set tags")

  def get_background_span_service(self) -> Optional[BackgroundSpanService]:
  """Get the background span service instance."""
@@ -2447,7 +2327,7 @@ def wrap(
  # Warn about token counting limitations with streaming
  if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
  if not kwargs.get("stream_options", {}).get("include_usage"):
- warnings.warn(
+ judgeval_logger.warning(
  "OpenAI streaming calls don't include token counts by default. "
  "To enable token counting with streams, set stream_options={'include_usage': True} "
  "in your API call arguments.",
@@ -2746,19 +2626,25 @@ def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
2746
2626
  prompt_tokens = 0
2747
2627
  completion_tokens = 0
2748
2628
  model_name = None
2749
- if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
2629
+ cache_read_input_tokens = 0
2630
+ cache_creation_input_tokens = 0
2631
+
2632
+ if isinstance(client, (OpenAI, AsyncOpenAI)):
2750
2633
  model_name = response.model
2751
2634
  prompt_tokens = response.usage.input_tokens
2752
2635
  completion_tokens = response.usage.output_tokens
2753
- message_content = response.output
2636
+ cache_read_input_tokens = response.usage.input_tokens_details.cached_tokens
2637
+ message_content = "".join(seg.text for seg in response.output[0].content)
2754
2638
  else:
2755
- warnings.warn(f"Unsupported client type: {type(client)}")
2639
+ judgeval_logger.warning(f"Unsupported client type: {type(client)}")
2756
2640
  return None, None
2757
2641
 
2758
2642
  prompt_cost, completion_cost = cost_per_token(
2759
2643
  model=model_name,
2760
2644
  prompt_tokens=prompt_tokens,
2761
2645
  completion_tokens=completion_tokens,
2646
+ cache_read_input_tokens=cache_read_input_tokens,
2647
+ cache_creation_input_tokens=cache_creation_input_tokens,
2762
2648
  )
2763
2649
  total_cost_usd = (
2764
2650
  (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
@@ -2767,6 +2653,8 @@ def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
2767
2653
  prompt_tokens=prompt_tokens,
2768
2654
  completion_tokens=completion_tokens,
2769
2655
  total_tokens=prompt_tokens + completion_tokens,
2656
+ cache_read_input_tokens=cache_read_input_tokens,
2657
+ cache_creation_input_tokens=cache_creation_input_tokens,
2770
2658
  prompt_tokens_cost_usd=prompt_cost,
2771
2659
  completion_tokens_cost_usd=completion_cost,
2772
2660
  total_cost_usd=total_cost_usd,
@@ -2790,13 +2678,28 @@ def _format_output_data(
2790
2678
  """
2791
2679
  prompt_tokens = 0
2792
2680
  completion_tokens = 0
2681
+ cache_read_input_tokens = 0
2682
+ cache_creation_input_tokens = 0
2793
2683
  model_name = None
2794
2684
  message_content = None
2795
2685
 
2796
- if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
2686
+ if isinstance(client, (OpenAI, AsyncOpenAI)):
2797
2687
  model_name = response.model
2798
2688
  prompt_tokens = response.usage.prompt_tokens
2799
2689
  completion_tokens = response.usage.completion_tokens
2690
+ cache_read_input_tokens = response.usage.prompt_tokens_details.cached_tokens
2691
+ if (
2692
+ hasattr(response.choices[0].message, "parsed")
2693
+ and response.choices[0].message.parsed
2694
+ ):
2695
+ message_content = response.choices[0].message.parsed
2696
+ else:
2697
+ message_content = response.choices[0].message.content
2698
+
2699
+ elif isinstance(client, (Together, AsyncTogether)):
2700
+ model_name = "together_ai/" + response.model
2701
+ prompt_tokens = response.usage.prompt_tokens
2702
+ completion_tokens = response.usage.completion_tokens
2800
2703
  if (
2801
2704
  hasattr(response.choices[0].message, "parsed")
2802
2705
  and response.choices[0].message.parsed
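The _format_output_data hunk above splits the OpenAI and Together branches and prefixes Together model names with "together_ai/", which is the provider-prefixed form LiteLLM uses to resolve pricing. A hedged sketch of that lookup (the model name is illustrative and may not be in LiteLLM's price map):

import litellm

# LiteLLM resolves cost by provider-prefixed model name, so Together models
# are looked up as "together_ai/<model>" (example model name, assumption).
prompt_cost, completion_cost = litellm.cost_per_token(
    model="together_ai/meta-llama/Llama-3-8b-chat-hf",
    prompt_tokens=1200,
    completion_tokens=300,
)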
@@ -2813,15 +2716,19 @@ def _format_output_data(
  model_name = response.model
  prompt_tokens = response.usage.input_tokens
  completion_tokens = response.usage.output_tokens
+ cache_read_input_tokens = response.usage.cache_read_input_tokens
+ cache_creation_input_tokens = response.usage.cache_creation_input_tokens
  message_content = response.content[0].text
  else:
- warnings.warn(f"Unsupported client type: {type(client)}")
+ judgeval_logger.warning(f"Unsupported client type: {type(client)}")
  return None, None

  prompt_cost, completion_cost = cost_per_token(
  model=model_name,
  prompt_tokens=prompt_tokens,
  completion_tokens=completion_tokens,
+ cache_read_input_tokens=cache_read_input_tokens,
+ cache_creation_input_tokens=cache_creation_input_tokens,
  )
  total_cost_usd = (
  (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
@@ -2830,6 +2737,8 @@ def _format_output_data(
  prompt_tokens=prompt_tokens,
  completion_tokens=completion_tokens,
  total_tokens=prompt_tokens + completion_tokens,
+ cache_read_input_tokens=cache_read_input_tokens,
+ cache_creation_input_tokens=cache_creation_input_tokens,
  prompt_tokens_cost_usd=prompt_cost,
  completion_tokens_cost_usd=completion_cost,
  total_cost_usd=total_cost_usd,
@@ -2969,8 +2878,11 @@ def _extract_usage_from_final_chunk(
  prompt_tokens = chunk.choices[0].usage.prompt_tokens
  completion_tokens = chunk.choices[0].usage.completion_tokens

+ if isinstance(client, (Together, AsyncTogether)):
+ model_name = "together_ai/" + chunk.model
+
  prompt_cost, completion_cost = cost_per_token(
- model=chunk.model,
+ model=model_name,
  prompt_tokens=prompt_tokens,
  completion_tokens=completion_tokens,
  )
@@ -3161,9 +3073,17 @@
 
  def cost_per_token(*args, **kwargs):
  try:
- return _original_cost_per_token(*args, **kwargs)
+ prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = (
+ _original_cost_per_token(*args, **kwargs)
+ )
+ if (
+ prompt_tokens_cost_usd_dollar == 0
+ and completion_tokens_cost_usd_dollar == 0
+ ):
+ judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
+ return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
  except Exception as e:
- warnings.warn(f"Error calculating cost per token: {e}")
+ judgeval_logger.warning(f"Error calculating cost per token: {e}")
  return None, None

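The final hunk wraps the underlying cost function so that an all-zero result is logged instead of being silently treated as a price. A minimal standalone sketch of the same wrapper pattern, assuming LiteLLM's top-level cost_per_token (the logger here is a stand-in for judgeval_logger):

import logging
from litellm import cost_per_token as _original_cost_per_token

logger = logging.getLogger("judgeval")  # stand-in for judgeval_logger

def cost_per_token(*args, **kwargs):
    try:
        prompt_cost, completion_cost = _original_cost_per_token(*args, **kwargs)
        if prompt_cost == 0 and completion_cost == 0:
            # Zero on both sides usually means the model was not found in the price map
            logger.warning("LiteLLM returned a total of 0 for cost per token")
        return prompt_cost, completion_cost
    except Exception as e:
        logger.warning(f"Error calculating cost per token: {e}")
        return None, None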