judgeval 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -44,6 +44,7 @@ from openai import OpenAI, AsyncOpenAI
  from together import Together, AsyncTogether
  from anthropic import Anthropic, AsyncAnthropic
  from google import genai
+ from judgeval.run_evaluation import check_examples
 
  # Local application/library-specific imports
  from judgeval.constants import (
@@ -75,6 +76,17 @@ in_traced_function_var = contextvars.ContextVar('in_traced_function', default=Fa
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
  TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
  SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
+
+ # --- Evaluation Config Dataclass (Moved from langgraph.py) ---
+ @dataclass
+ class EvaluationConfig:
+ """Configuration for triggering an evaluation from the handler."""
+ scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
+ example: Example
+ model: Optional[str] = None
+ log_results: Optional[bool] = True
+ # --- End Evaluation Config Dataclass ---
+
  @dataclass
  class TraceEntry:
  """Represents a single trace entry with its visual representation.
@@ -197,29 +209,31 @@ class TraceEntry:
 
  Handles special cases:
  - Pydantic models are converted using model_dump()
+ - Dictionaries are processed recursively to handle non-serializable values.
  - We try to serialize into JSON, then string, then the base representation (__repr__)
  - Non-serializable objects return None with a warning
  """
-
- if isinstance(self.output, BaseModel):
- return self.output.model_dump()
-
- # NEW check: If output is the dict structure from our stream wrapper
- if isinstance(self.output, dict) and 'streamed' in self.output:
- # Assume it's already JSON-serializable (content is string, usage is dict or None)
- return self.output
- # NEW check: If output is the placeholder string before stream completes
- elif self.output == "<pending stream>":
- # Represent this state clearly in the serialized data
- return {"status": "pending stream"}
 
- try:
- # Try to serialize the output to verify it's JSON compatible
- json.dumps(self.output)
- return self.output
- except (TypeError, OverflowError, ValueError):
- return self.safe_stringify(self.output, self.function)
-
+ def serialize_value(value):
+ if isinstance(value, BaseModel):
+ return value.model_dump()
+ elif isinstance(value, dict):
+ # Recursively serialize dictionary values
+ return {k: serialize_value(v) for k, v in value.items()}
+ elif isinstance(value, (list, tuple)):
+ # Recursively serialize list/tuple items
+ return [serialize_value(item) for item in value]
+ else:
+ # Try direct JSON serialization first
+ try:
+ json.dumps(value)
+ return value
+ except (TypeError, OverflowError, ValueError):
+ # Fallback to safe stringification
+ return self.safe_stringify(value, self.function)
+
+ # Start serialization with the top-level output
+ return serialize_value(self.output)
 
  class TraceManagerClient:
  """
@@ -467,32 +481,24 @@ class TraceClient:
  def async_evaluate(
  self,
  scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+ example: Optional[Example] = None,
  input: Optional[str] = None,
- actual_output: Optional[str] = None,
- expected_output: Optional[str] = None,
+ actual_output: Optional[Union[str, List[str]]] = None,
+ expected_output: Optional[Union[str, List[str]]] = None,
  context: Optional[List[str]] = None,
  retrieval_context: Optional[List[str]] = None,
  tools_called: Optional[List[str]] = None,
  expected_tools: Optional[List[str]] = None,
  additional_metadata: Optional[Dict[str, Any]] = None,
  model: Optional[str] = None,
+ span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
  log_results: Optional[bool] = True
  ):
  if not self.enable_evaluations:
  return
 
  start_time = time.time() # Record start time
- example = Example(
- input=input,
- actual_output=actual_output,
- expected_output=expected_output,
- context=context,
- retrieval_context=retrieval_context,
- tools_called=tools_called,
- expected_tools=expected_tools,
- additional_metadata=additional_metadata,
- trace_id=self.trace_id
- )
+
  try:
  # Load appropriate implementations for all scorers
  if not scorers:
@@ -507,13 +513,44 @@ class TraceClient:
  warnings.warn(f"Failed to load scorers: {str(e)}")
  return
 
+ # If example is not provided, create one from the individual parameters
+ if example is None:
+ # Check if any of the individual parameters are provided
+ if any(param is not None for param in [input, actual_output, expected_output, context,
+ retrieval_context, tools_called, expected_tools,
+ additional_metadata]):
+ example = Example(
+ input=input,
+ actual_output=actual_output,
+ expected_output=expected_output,
+ context=context,
+ retrieval_context=retrieval_context,
+ tools_called=tools_called,
+ expected_tools=expected_tools,
+ additional_metadata=additional_metadata,
+ trace_id=self.trace_id
+ )
+ else:
+ raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
+
+ # Check examples before creating evaluation run
+ check_examples([example], scorers)
+
+ # --- Modification: Capture span_id immediately ---
+ # span_id_at_eval_call = current_span_var.get()
+ # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
+ # Prioritize explicitly passed span_id, fallback to context var
+ span_id_to_use = span_id if span_id is not None else current_span_var.get()
+ # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
+ # --- End Modification ---
+
  # Combine the trace-level rules with any evaluation-specific rules)
  eval_run = EvaluationRun(
  organization_id=self.tracer.organization_id,
  log_results=log_results,
  project_name=self.project_name,
  eval_name=f"{self.name.capitalize()}-"
- f"{current_span_var.get()}-"
+ f"{current_span_var.get()}-" # Keep original eval name format using context var if available
  f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
  examples=[example],
  scorers=scorers,
@@ -521,14 +558,18 @@ class TraceClient:
  metadata={},
  judgment_api_key=self.tracer.api_key,
  override=self.overwrite,
- trace_span_id=current_span_var.get(),
+ trace_span_id=span_id_to_use, # Pass the determined ID
  rules=self.rules # Use the combined rules
  )
 
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
 
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
- current_span_id = current_span_var.get()
+ # --- Modification: Use span_id from eval_run ---
+ current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
+ # print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
+ # --- End Modification ---
+
  if current_span_id:
  duration = time.time() - start_time
  prev_entry = self.entries[-1] if self.entries else None
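Note: TraceClient.async_evaluate now accepts either a pre-built Example or the individual fields, plus an optional span_id that pins the evaluation to a specific span when the context variable is unavailable. A hedged sketch of both calling styles (the scorer class, import paths, and tracer setup are illustrative, not taken from this diff):

    from judgeval.data import Example                    # assumed import path
    from judgeval.scorers import AnswerRelevancyScorer   # illustrative scorer

    trace = tracer.get_current_trace()   # Tracer.get_current_trace() appears later in this diff
    if trace:
        # Style 1: pass a ready-made Example and pin the span explicitly
        trace.async_evaluate(
            scorers=[AnswerRelevancyScorer(threshold=0.7)],
            example=Example(input="question", actual_output="answer"),
            span_id="<span-id-captured-by-the-handler>",  # optional; falls back to current_span_var
        )
        # Style 2: pass the individual fields and let the client build the Example
        trace.async_evaluate(
            scorers=[AnswerRelevancyScorer(threshold=0.7)],
            input="question",
            actual_output="answer",
            model="gpt-4o",
        )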
@@ -574,7 +615,7 @@ class TraceClient:
  self.add_entry(TraceEntry(
  type="input",
  function=function_name,
- span_id=current_span_id, # Use current span_id
+ span_id=current_span_id, # Use current span_id from context
  trace_id=self.trace_id, # Use the trace_id from the trace client
  depth=current_depth,
  message=f"Inputs to {function_name}",
@@ -582,6 +623,7 @@ class TraceClient:
  inputs=inputs,
  span_type=entry_span_type,
  ))
+ # Removed else block - original didn't have one
 
  async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
  """Helper method to update the output of a trace entry once the coroutine completes"""
@@ -608,20 +650,22 @@ class TraceClient:
  entry = TraceEntry(
  type="output",
  function=function_name,
- span_id=current_span_id, # Use current span_id
+ span_id=current_span_id, # Use current span_id from context
  depth=current_depth,
  message=f"Output from {function_name}",
  created_at=time.time(),
  output="<pending>" if inspect.iscoroutine(output) else output,
  span_type=entry_span_type,
+ trace_id=self.trace_id # Added trace_id for consistency
  )
  self.add_entry(entry)
 
  if inspect.iscoroutine(output):
  asyncio.create_task(self._update_coroutine_output(entry, output))
 
- # Return the created entry
- return entry
+ return entry # Return the created entry
+ # Removed else block - original didn't have one
+ return None # Return None if no span_id found
 
  def add_entry(self, entry: TraceEntry):
  """Add a trace entry to this trace context"""
@@ -824,78 +868,144 @@ class TraceClient:
 
  condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
 
- # Calculate total token counts from LLM API calls
- total_prompt_tokens = 0
- total_completion_tokens = 0
- total_tokens = 0
-
- total_prompt_tokens_cost = 0.0
- total_completion_tokens_cost = 0.0
- total_cost = 0.0
-
  # Only count tokens for actual LLM API call spans
  llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
  for entry in condensed_entries:
- if entry.get("span_type") == "llm" and entry.get("function") in llm_span_names and isinstance(entry.get("output"), dict):
+ entry_function_name = entry.get("function", "") # Get function name safely
+ # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
+ is_llm_entry = entry.get("span_type") == "llm"
+ has_api_suffix = any(suffix in entry_function_name for suffix in llm_span_names)
+ output_is_dict = isinstance(entry.get("output"), dict)
+
+ # --- DEBUG PRINT 1: Check if condition passes ---
+ # if is_llm_entry and has_api_suffix and output_is_dict:
+ # # print(f"[DEBUG TraceClient.save] Processing entry: {entry.get('span_id')} ({entry_function_name}) - Condition PASSED")
+ # elif is_llm_entry:
+ # # Print why it failed if it was an LLM entry
+ # print(f"[DEBUG TraceClient.save] Skipping LLM entry: {entry.get('span_id')} ({entry_function_name}) - Suffix Match: {has_api_suffix}, Output is Dict: {output_is_dict}")
+ # # --- END DEBUG ---
+
+ if is_llm_entry and has_api_suffix and output_is_dict:
  output = entry["output"]
- usage = output.get("usage", {})
- model_name = entry.get("inputs", {}).get("model", "")
+ usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
+
+ # --- DEBUG PRINT 2: Check extracted usage ---
+ # print(f"[DEBUG TraceClient.save] Extracted usage dict: {usage}")
+ # --- END DEBUG ---
+
+ # --- NEW: Extract model_name correctly from nested inputs ---
+ model_name = None
+ entry_inputs = entry.get("inputs", {})
+ # print(f"[DEBUG TraceClient.save] Inspecting inputs for span {entry.get('span_id')}: {entry_inputs}") # DEBUG Inputs
+ if entry_inputs:
+ # Try common locations for model name within the inputs structure
+ invocation_params = entry_inputs.get("invocation_params", {})
+ serialized_data = entry_inputs.get("serialized", {})
+
+ # Look in invocation_params (often directly contains model)
+ if isinstance(invocation_params, dict):
+ model_name = invocation_params.get("model")
+
+ # Fallback: Check serialized 'repr' if it contains model info
+ if not model_name and isinstance(serialized_data, dict):
+ serialized_repr = serialized_data.get("repr", "")
+ if "model_name=" in serialized_repr:
+ try: # Simple parsing attempt
+ model_name = serialized_repr.split("model_name='")[1].split("'")[0]
+ except IndexError: pass # Ignore parsing errors
+
+ # Fallback: Check top-level of invocation_params (sometimes passed flat)
+ if not model_name and isinstance(invocation_params, dict):
+ model_name = invocation_params.get("model") # Redundant check, but safe
+
+ # Fallback: Check top-level of inputs itself (less likely for callbacks)
+ if not model_name:
+ model_name = entry_inputs.get("model")
+
+
+ # print(f"[DEBUG TraceClient.save] Determined model_name: {model_name}") # DEBUG Model Name
+ # --- END NEW ---
+
  prompt_tokens = 0
- completion_tokens = 0
-
- # Handle OpenAI/Together format
+ completion_tokens = 0
+
+ # Handle OpenAI/Together format (checks within the 'usage' dict)
  if "prompt_tokens" in usage:
  prompt_tokens = usage.get("prompt_tokens", 0)
  completion_tokens = usage.get("completion_tokens", 0)
- total_prompt_tokens += prompt_tokens
- total_completion_tokens += completion_tokens
- # Handle Anthropic format
+
+ # Handle Anthropic format - MAP values to standard keys
  elif "input_tokens" in usage:
- prompt_tokens = usage.get("input_tokens", 0)
- completion_tokens = usage.get("output_tokens", 0)
- total_prompt_tokens += prompt_tokens
- total_completion_tokens += completion_tokens
-
- total_tokens += usage.get("total_tokens", 0)
+ prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
+ completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
+
+ # *** Overwrite the usage dict in the entry to use standard keys ***
+ original_total = usage.get("total_tokens", 0)
+ original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
+ # Recalculate cost just in case it wasn't done correctly before
+ temp_prompt_cost, temp_completion_cost = 0.0, 0.0
+ if model_name:
+ try:
+ temp_prompt_cost, temp_completion_cost = cost_per_token(
+ model=model_name,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens
+ )
+ except Exception:
+ pass # Ignore cost calculation errors here, focus on keys
+ # Replace the usage dict with one using standard keys but Anthropic values
+ output["usage"] = {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": original_total,
+ "prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
+ "completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
+ "total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
+ }
+ usage = output["usage"]
+
+ # Calculate costs if model name is available and ensure they are stored with standard keys
+ prompt_tokens = usage.get("prompt_tokens", 0)
+ completion_tokens = usage.get("completion_tokens", 0)
 
  # Calculate costs if model name is available
  if model_name:
  try:
+ # Recalculate costs based on potentially mapped tokens
  prompt_cost, completion_cost = cost_per_token(
  model=model_name,
  prompt_tokens=prompt_tokens,
  completion_tokens=completion_tokens
  )
- total_prompt_tokens_cost += prompt_cost
- total_completion_tokens_cost += completion_cost
- total_cost += prompt_cost + completion_cost
 
  # Add cost information directly to the usage dictionary in the condensed entry
+ # Ensure 'usage' exists in the output dict before modifying it
+ # Add/Update cost information using standard keys
+
  if "usage" not in output:
- output["usage"] = {}
+ output["usage"] = {} # Initialize if missing
+ elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
+ print(f"[WARN TraceClient.save] Output 'usage' for span {entry.get('span_id')} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
+ output["usage"] = {} # Reset to dict
+
  output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
  output["usage"]["completion_tokens_cost_usd"] = completion_cost
  output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
  except Exception as e:
  # If cost calculation fails, continue without adding costs
- print(f"Error calculating cost for model '{model_name}': {str(e)}")
+ print(f"Error calculating cost for model '{model_name}' (span: {entry.get('span_id')}): {str(e)}")
  pass
+ else:
+ print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {entry.get('span_id')}). Inputs: {entry_inputs}")
+
 
- # Create trace document
+ # Create trace document - Always use standard keys for top-level counts
  trace_data = {
  "trace_id": self.trace_id,
  "name": self.name,
  "project_name": self.project_name,
  "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
  "duration": total_duration,
- "token_counts": {
- "prompt_tokens": total_prompt_tokens,
- "completion_tokens": total_completion_tokens,
- "total_tokens": total_tokens,
- "prompt_tokens_cost_usd": total_prompt_tokens_cost,
- "completion_tokens_cost_usd": total_completion_tokens_cost,
- "total_cost_usd": total_cost
- },
  "entries": condensed_entries,
  "evaluation_runs": evaluation_runs,
  "overwrite": overwrite,
@@ -903,12 +1013,6 @@ class TraceClient:
  "parent_name": self.parent_name
  }
  # --- Log trace data before saving ---
- try:
- rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
- rprint(json.dumps(trace_data, indent=2))
- except Exception as log_e:
- rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
- # --- End logging ---
  self.trace_manager_client.save_trace(trace_data)
 
  return self.trace_id, trace_data
@@ -958,6 +1062,7 @@ class Tracer:
  self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
  self.organization_id: str = organization_id
  self._current_trace: Optional[str] = None
+ self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
  self.initialized: bool = True
  self.enable_monitoring: bool = enable_monitoring
@@ -991,10 +1096,29 @@ class Tracer:
 
  def get_current_trace(self) -> Optional[TraceClient]:
  """
- Get the current trace context from contextvars
+ Get the current trace context.
+
+ Tries to get the trace client from the context variable first.
+ If not found (e.g., context lost across threads/tasks),
+ it falls back to the active trace client managed by the callback handler.
  """
- return current_trace_var.get()
+ trace_from_context = current_trace_var.get()
+ if trace_from_context:
+ return trace_from_context
 
+ # Fallback: Check the active client potentially set by a callback handler
+ if hasattr(self, '_active_trace_client') and self._active_trace_client:
+ # warnings.warn("Falling back to _active_trace_client in get_current_trace. ContextVar might be lost.", RuntimeWarning)
+ return self._active_trace_client
+
+ # If neither is available
+ # warnings.warn("No current trace found in context variable or active client fallback.", RuntimeWarning)
+ return None
+
+ def get_active_trace_client(self) -> Optional[TraceClient]:
+ """Returns the TraceClient instance currently marked as active by the handler."""
+ return self._active_trace_client
+
  def _apply_deep_tracing(self, func, span_type="span"):
  """
  Apply deep tracing to all functions in the same module as the given function.
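Note: the fallback to _active_trace_client matters because context variables do not propagate into worker threads that callback frameworks may use. A minimal standalone demonstration of that behaviour (independent of judgeval):

    import contextvars
    import threading

    current_trace_var = contextvars.ContextVar("current_trace", default=None)
    current_trace_var.set("trace-123")

    def worker():
        # A new thread starts with an empty context, so the value set above is invisible here;
        # this is the situation the _active_trace_client fallback covers.
        print(current_trace_var.get())  # prints None

    t = threading.Thread(target=worker)
    t.start()
    t.join()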
@@ -1314,43 +1438,29 @@ class Tracer:
 
  return wrapper
 
- def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
- """
- Decorator to trace function execution with detailed entry/exit information.
- """
- if func is None:
- return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
-
- if asyncio.iscoroutinefunction(func):
- @functools.wraps(func)
- async def async_wrapper(*args, **kwargs):
- # Get current trace from contextvars
- current_trace = current_trace_var.get()
- if current_trace and scorers:
- current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
- return await func(*args, **kwargs)
- return async_wrapper
- else:
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- # Get current trace from contextvars
- current_trace = current_trace_var.get()
- if current_trace and scorers:
- current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
- return func(*args, **kwargs)
- return wrapper
-
  def async_evaluate(self, *args, **kwargs):
  if not self.enable_evaluations:
  return
 
- # Get current trace from context
+ # --- Get trace_id passed explicitly (if any) ---
+ passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
+
+ # --- Get current trace from context FIRST ---
  current_trace = current_trace_var.get()
-
+
+ # --- Fallback Logic: Use active client only if context var is empty ---
+ if not current_trace:
+ current_trace = self._active_trace_client # Use the fallback
+ # --- End Fallback Logic ---
+
  if current_trace:
+ # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
+ # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
+ if passed_trace_id:
+ kwargs['trace_id'] = passed_trace_id # Re-add if needed by TraceClient.async_evaluate
  current_trace.async_evaluate(*args, **kwargs)
  else:
- warnings.warn("No trace found, skipping evaluation")
+ warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
 
 
  def wrap(client: Any) -> Any:
@@ -1600,8 +1710,8 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
  return {
  "content": response.content[0].text,
  "usage": {
- "input_tokens": response.usage.input_tokens,
- "output_tokens": response.usage.output_tokens,
+ "prompt_tokens": response.usage.input_tokens,
+ "completion_tokens": response.usage.output_tokens,
  "total_tokens": response.usage.input_tokens + response.usage.output_tokens
  }
  }
@@ -1891,8 +2001,8 @@ async def _async_stream_wrapper(
  anthropic_final_usage = None
  if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
  anthropic_final_usage = {
- "input_tokens": anthropic_input_tokens,
- "output_tokens": anthropic_output_tokens,
+ "prompt_tokens": anthropic_input_tokens,
+ "completion_tokens": anthropic_output_tokens,
  "total_tokens": anthropic_input_tokens + anthropic_output_tokens
  }
 
@@ -2080,3 +2190,127 @@ class _TracedSyncStreamManagerWrapper(AbstractContextManager):
  if hasattr(self._original_manager, "__exit__"):
  return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
  return None
+
+ # --- NEW Generalized Helper Function (Moved from demo) ---
+ def prepare_evaluation_for_state(
+ scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+ example: Optional[Example] = None,
+ # --- Individual components (alternative to 'example') ---
+ input: Optional[str] = None,
+ actual_output: Optional[Union[str, List[str]]] = None,
+ expected_output: Optional[Union[str, List[str]]] = None,
+ context: Optional[List[str]] = None,
+ retrieval_context: Optional[List[str]] = None,
+ tools_called: Optional[List[str]] = None,
+ expected_tools: Optional[List[str]] = None,
+ additional_metadata: Optional[Dict[str, Any]] = None,
+ # --- Other eval parameters ---
+ model: Optional[str] = None,
+ log_results: Optional[bool] = True
+ ) -> Optional[EvaluationConfig]:
+ """
+ Prepares an EvaluationConfig object, similar to TraceClient.async_evaluate.
+
+ Accepts either a pre-made Example object or individual components to construct one.
+ Returns the EvaluationConfig object ready to be placed in the state, or None.
+ """
+ final_example = example
+
+ # If example is not provided, try to construct one from individual parts
+ if final_example is None:
+ # Basic validation: Ensure at least actual_output is present for most scorers
+ if actual_output is None:
+ # print("[prepare_evaluation_for_state] Warning: 'actual_output' is required when 'example' is not provided. Skipping evaluation setup.")
+ return None
+ try:
+ final_example = Example(
+ input=input,
+ actual_output=actual_output,
+ expected_output=expected_output,
+ context=context,
+ retrieval_context=retrieval_context,
+ tools_called=tools_called,
+ expected_tools=expected_tools,
+ additional_metadata=additional_metadata,
+ # trace_id will be set by the handler later if needed
+ )
+ # print("[prepare_evaluation_for_state] Constructed Example from individual components.")
+ except Exception as e:
+ # print(f"[prepare_evaluation_for_state] Error constructing Example: {e}. Skipping evaluation setup.")
+ return None
+
+ # If we have a valid example (provided or constructed) and scorers
+ if final_example and scorers:
+ # TODO: Add validation like check_examples if needed here,
+ # although the handler might implicitly handle some checks via TraceClient.
+ return EvaluationConfig(
+ scorers=scorers,
+ example=final_example,
+ model=model,
+ log_results=log_results
+ )
+ elif not scorers:
+ # print("[prepare_evaluation_for_state] No scorers provided. Skipping evaluation setup.")
+ return None
+ else: # No valid example
+ # print("[prepare_evaluation_for_state] No valid Example available. Skipping evaluation setup.")
+ return None
+ # --- End NEW Helper Function ---
+
+ # --- NEW: Helper function to simplify adding eval config to state ---
+ def add_evaluation_to_state(
+ state: Dict[str, Any], # The LangGraph state dictionary
+ scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+ # --- Evaluation components (same as prepare_evaluation_for_state) ---
+ input: Optional[str] = None,
+ actual_output: Optional[Union[str, List[str]]] = None,
+ expected_output: Optional[Union[str, List[str]]] = None,
+ context: Optional[List[str]] = None,
+ retrieval_context: Optional[List[str]] = None,
+ tools_called: Optional[List[str]] = None,
+ expected_tools: Optional[List[str]] = None,
+ additional_metadata: Optional[Dict[str, Any]] = None,
+ # --- Other eval parameters ---
+ model: Optional[str] = None,
+ log_results: Optional[bool] = True
+ ) -> None:
+ """
+ Prepares an EvaluationConfig and adds it to the state dictionary
+ under the '_judgeval_eval' key if successful.
+
+ This simplifies the process of setting up evaluations within LangGraph nodes.
+
+ Args:
+ state: The LangGraph state dictionary to modify.
+ scorers: List of scorer instances.
+ input: Input for the evaluation example.
+ actual_output: Actual output for the evaluation example.
+ expected_output: Expected output for the evaluation example.
+ context: Context for the evaluation example.
+ retrieval_context: Retrieval context for the evaluation example.
+ tools_called: Tools called for the evaluation example.
+ expected_tools: Expected tools for the evaluation example.
+ additional_metadata: Additional metadata for the evaluation example.
+ model: Model name used for generation (optional).
+ log_results: Whether to log evaluation results (optional, defaults to True).
+ """
+ eval_config = prepare_evaluation_for_state(
+ scorers=scorers,
+ input=input,
+ actual_output=actual_output,
+ expected_output=expected_output,
+ context=context,
+ retrieval_context=retrieval_context,
+ tools_called=tools_called,
+ expected_tools=expected_tools,
+ additional_metadata=additional_metadata,
+ model=model,
+ log_results=log_results
+ )
+
+ if eval_config:
+ state["_judgeval_eval"] = eval_config
+ # print(f"[_judgeval_eval added to state for node]") # Optional: Log confirmation
+
+ # print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
+ # --- End NEW Helper ---
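Note: these helpers are intended to be called from inside LangGraph nodes so the Judgeval callback handler can pick the evaluation up from state["_judgeval_eval"]. A hedged usage sketch; the node, state keys, generation call, and scorer below are illustrative assumptions, not part of this diff:

    from judgeval.scorers import AnswerRelevancyScorer   # illustrative scorer

    def answer_node(state: dict) -> dict:
        question = state["question"]
        answer = my_llm_call(question)        # hypothetical generation step

        # Stashes an EvaluationConfig under state["_judgeval_eval"] for the
        # Judgeval callback handler to pick up when the node finishes.
        add_evaluation_to_state(
            state,
            scorers=[AnswerRelevancyScorer(threshold=0.7)],
            input=question,
            actual_output=answer,
            model="gpt-4o",
        )
        return {**state, "answer": answer}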