judgeval 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +352 -118
- judgeval/constants.py +3 -2
- judgeval/data/datasets/dataset.py +3 -0
- judgeval/data/datasets/eval_dataset_client.py +63 -3
- judgeval/integrations/langgraph.py +1961 -299
- judgeval/judgment_client.py +8 -2
- judgeval/run_evaluation.py +67 -18
- judgeval/scorers/score.py +1 -0
- {judgeval-0.0.35.dist-info → judgeval-0.0.36.dist-info}/METADATA +1 -2
- {judgeval-0.0.35.dist-info → judgeval-0.0.36.dist-info}/RECORD +12 -12
- {judgeval-0.0.35.dist-info → judgeval-0.0.36.dist-info}/WHEEL +0 -0
- {judgeval-0.0.35.dist-info → judgeval-0.0.36.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -44,6 +44,7 @@ from openai import OpenAI, AsyncOpenAI
 from together import Together, AsyncTogether
 from anthropic import Anthropic, AsyncAnthropic
 from google import genai
+from judgeval.run_evaluation import check_examples
 
 # Local application/library-specific imports
 from judgeval.constants import (
@@ -75,6 +76,17 @@ in_traced_function_var = contextvars.ContextVar('in_traced_function', default=Fa
 ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
 TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
 SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
+
+# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
+@dataclass
+class EvaluationConfig:
+    """Configuration for triggering an evaluation from the handler."""
+    scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
+    example: Example
+    model: Optional[str] = None
+    log_results: Optional[bool] = True
+# --- End Evaluation Config Dataclass ---
+
 @dataclass
 class TraceEntry:
     """Represents a single trace entry with its visual representation.
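The `EvaluationConfig` dataclass moved here from `langgraph.py` is the unit the callback handler consumes when an evaluation is attached to graph state (see the new helpers at the end of this file). A minimal construction sketch; the `Example` and scorer import paths are assumptions based on the rest of this diff:

```python
# Sketch only: EvaluationConfig is defined in judgeval/common/tracer.py (this diff);
# the Example and FaithfulnessScorer import paths are assumptions.
from judgeval.common.tracer import EvaluationConfig
from judgeval.data import Example
from judgeval.scorers import FaithfulnessScorer

config = EvaluationConfig(
    scorers=[FaithfulnessScorer(threshold=0.8)],   # scorers to run
    example=Example(
        input="What is the capital of France?",
        actual_output="Paris is the capital of France.",
        retrieval_context=["Paris is the capital and largest city of France."],
    ),
    model="gpt-4o-mini",   # optional judge model
    log_results=True,      # defaults to True
)
```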
@@ -197,29 +209,31 @@ class TraceEntry:
 
         Handles special cases:
         - Pydantic models are converted using model_dump()
+        - Dictionaries are processed recursively to handle non-serializable values.
         - We try to serialize into JSON, then string, then the base representation (__repr__)
         - Non-serializable objects return None with a warning
         """
-
-        if isinstance(self.output, BaseModel):
-            return self.output.model_dump()
-
-        # NEW check: If output is the dict structure from our stream wrapper
-        if isinstance(self.output, dict) and 'streamed' in self.output:
-            # Assume it's already JSON-serializable (content is string, usage is dict or None)
-            return self.output
-        # NEW check: If output is the placeholder string before stream completes
-        elif self.output == "<pending stream>":
-            # Represent this state clearly in the serialized data
-            return {"status": "pending stream"}
 
-
-
-
-
-
-
-
+        def serialize_value(value):
+            if isinstance(value, BaseModel):
+                return value.model_dump()
+            elif isinstance(value, dict):
+                # Recursively serialize dictionary values
+                return {k: serialize_value(v) for k, v in value.items()}
+            elif isinstance(value, (list, tuple)):
+                # Recursively serialize list/tuple items
+                return [serialize_value(item) for item in value]
+            else:
+                # Try direct JSON serialization first
+                try:
+                    json.dumps(value)
+                    return value
+                except (TypeError, OverflowError, ValueError):
+                    # Fallback to safe stringification
+                    return self.safe_stringify(value, self.function)
+
+        # Start serialization with the top-level output
+        return serialize_value(self.output)
 
 class TraceManagerClient:
     """
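The rewritten output serialization replaces the earlier special cases with a single recursive helper: Pydantic models become dicts via `model_dump()`, dicts and lists are walked element by element, and anything `json.dumps` rejects falls back to safe stringification. The same pattern as a standalone, runnable sketch (using `repr` where the class method calls `self.safe_stringify`):

```python
import json
from pydantic import BaseModel

def serialize_value(value):
    """Recursively convert a value into something json.dumps can handle."""
    if isinstance(value, BaseModel):
        return value.model_dump()                 # Pydantic models -> plain dicts
    if isinstance(value, dict):
        return {k: serialize_value(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [serialize_value(item) for item in value]
    try:
        json.dumps(value)                         # already JSON-serializable?
        return value
    except (TypeError, OverflowError, ValueError):
        return repr(value)                        # fallback stringification

class Point(BaseModel):
    x: int
    y: int

print(serialize_value({"point": Point(x=1, y=2), "tags": {"a", "b"}}))
# {'point': {'x': 1, 'y': 2}, 'tags': "{'a', 'b'}"}  (set order may vary)
```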
@@ -467,32 +481,24 @@ class TraceClient:
     def async_evaluate(
         self,
         scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+        example: Optional[Example] = None,
         input: Optional[str] = None,
-        actual_output: Optional[str] = None,
-        expected_output: Optional[str] = None,
+        actual_output: Optional[Union[str, List[str]]] = None,
+        expected_output: Optional[Union[str, List[str]]] = None,
         context: Optional[List[str]] = None,
         retrieval_context: Optional[List[str]] = None,
         tools_called: Optional[List[str]] = None,
         expected_tools: Optional[List[str]] = None,
         additional_metadata: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
+        span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
         log_results: Optional[bool] = True
     ):
         if not self.enable_evaluations:
             return
 
         start_time = time.time() # Record start time
-
-            input=input,
-            actual_output=actual_output,
-            expected_output=expected_output,
-            context=context,
-            retrieval_context=retrieval_context,
-            tools_called=tools_called,
-            expected_tools=expected_tools,
-            additional_metadata=additional_metadata,
-            trace_id=self.trace_id
-        )
+
         try:
             # Load appropriate implementations for all scorers
             if not scorers:
@@ -507,13 +513,44 @@ class TraceClient:
             warnings.warn(f"Failed to load scorers: {str(e)}")
             return
 
+        # If example is not provided, create one from the individual parameters
+        if example is None:
+            # Check if any of the individual parameters are provided
+            if any(param is not None for param in [input, actual_output, expected_output, context,
+                                                   retrieval_context, tools_called, expected_tools,
+                                                   additional_metadata]):
+                example = Example(
+                    input=input,
+                    actual_output=actual_output,
+                    expected_output=expected_output,
+                    context=context,
+                    retrieval_context=retrieval_context,
+                    tools_called=tools_called,
+                    expected_tools=expected_tools,
+                    additional_metadata=additional_metadata,
+                    trace_id=self.trace_id
+                )
+            else:
+                raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
+
+        # Check examples before creating evaluation run
+        check_examples([example], scorers)
+
+        # --- Modification: Capture span_id immediately ---
+        # span_id_at_eval_call = current_span_var.get()
+        # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
+        # Prioritize explicitly passed span_id, fallback to context var
+        span_id_to_use = span_id if span_id is not None else current_span_var.get()
+        # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
+        # --- End Modification ---
+
         # Combine the trace-level rules with any evaluation-specific rules)
         eval_run = EvaluationRun(
             organization_id=self.tracer.organization_id,
             log_results=log_results,
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
-                      f"{current_span_var.get()}-"
+                      f"{current_span_var.get()}-" # Keep original eval name format using context var if available
                       f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
             examples=[example],
             scorers=scorers,
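Taken together with the new signature above, `TraceClient.async_evaluate` now accepts either a pre-built `Example` (optionally pinned to a specific span via `span_id`) or the individual fields, from which it constructs an `Example`, runs `check_examples`, and only then builds the `EvaluationRun`. A hedged usage sketch; the scorer and `Example` import paths are assumptions:

```python
# Sketch: 'trace' is an active TraceClient, e.g. judgment.get_current_trace()
# inside a traced function; scorer and Example import paths are assumptions.
from judgeval.data import Example
from judgeval.scorers import AnswerRelevancyScorer

def run_span_evals(trace, span_id=None):
    # Option 1: pass a ready-made Example, optionally pinning it to a span.
    trace.async_evaluate(
        scorers=[AnswerRelevancyScorer(threshold=0.7)],
        example=Example(input="What is the capital of France?", actual_output="Paris."),
        span_id=span_id,        # falls back to current_span_var when None
        model="gpt-4o-mini",
    )
    # Option 2: pass individual fields; the Example is built and validated internally.
    trace.async_evaluate(
        scorers=[AnswerRelevancyScorer(threshold=0.7)],
        input="What is the capital of France?",
        actual_output="Paris.",
        model="gpt-4o-mini",
    )
```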
@@ -521,14 +558,18 @@ class TraceClient:
             metadata={},
             judgment_api_key=self.tracer.api_key,
             override=self.overwrite,
-            trace_span_id=
+            trace_span_id=span_id_to_use, # Pass the determined ID
             rules=self.rules # Use the combined rules
         )
 
         self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
 
     def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
-
+        # --- Modification: Use span_id from eval_run ---
+        current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
+        # print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
+        # --- End Modification ---
+
         if current_span_id:
             duration = time.time() - start_time
             prev_entry = self.entries[-1] if self.entries else None
@@ -574,7 +615,7 @@ class TraceClient:
             self.add_entry(TraceEntry(
                 type="input",
                 function=function_name,
-                span_id=current_span_id, # Use current span_id
+                span_id=current_span_id, # Use current span_id from context
                 trace_id=self.trace_id, # Use the trace_id from the trace client
                 depth=current_depth,
                 message=f"Inputs to {function_name}",
@@ -582,6 +623,7 @@
                 inputs=inputs,
                 span_type=entry_span_type,
             ))
+        # Removed else block - original didn't have one
 
     async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
         """Helper method to update the output of a trace entry once the coroutine completes"""
@@ -608,20 +650,22 @@ class TraceClient:
             entry = TraceEntry(
                 type="output",
                 function=function_name,
-                span_id=current_span_id, # Use current span_id
+                span_id=current_span_id, # Use current span_id from context
                 depth=current_depth,
                 message=f"Output from {function_name}",
                 created_at=time.time(),
                 output="<pending>" if inspect.iscoroutine(output) else output,
                 span_type=entry_span_type,
+                trace_id=self.trace_id # Added trace_id for consistency
             )
             self.add_entry(entry)
 
             if inspect.iscoroutine(output):
                 asyncio.create_task(self._update_coroutine_output(entry, output))
 
-            # Return the created entry
-
+            return entry # Return the created entry
+        # Removed else block - original didn't have one
+        return None # Return None if no span_id found
 
     def add_entry(self, entry: TraceEntry):
         """Add a trace entry to this trace context"""
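Output recording now stores a `"<pending>"` placeholder when the traced callable returns a coroutine, schedules `_update_coroutine_output` as a background task to overwrite it with the awaited result, and returns the created entry (or `None` when no span is active). The underlying pattern, reduced to a self-contained sketch independent of judgeval:

```python
import asyncio
import inspect
import time
from dataclasses import dataclass
from typing import Any

@dataclass
class Record:
    created_at: float
    output: Any = "<pending>"

async def _resolve_later(record: Record, coro) -> Any:
    """Await the coroutine and patch the stored placeholder with the real result."""
    record.output = await coro
    return record.output

def capture_output(output: Any) -> Record:
    record = Record(created_at=time.time(),
                    output="<pending>" if inspect.iscoroutine(output) else output)
    if inspect.iscoroutine(output):
        # Fire-and-forget: the record is updated once the coroutine completes.
        asyncio.create_task(_resolve_later(record, output))
    return record

async def main():
    async def slow():
        await asyncio.sleep(0.1)
        return 42
    rec = capture_output(slow())
    print(rec.output)          # "<pending>"
    await asyncio.sleep(0.2)
    print(rec.output)          # 42

asyncio.run(main())
```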
@@ -824,78 +868,144 @@ class TraceClient:
 
         condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
 
-        # Calculate total token counts from LLM API calls
-        total_prompt_tokens = 0
-        total_completion_tokens = 0
-        total_tokens = 0
-
-        total_prompt_tokens_cost = 0.0
-        total_completion_tokens_cost = 0.0
-        total_cost = 0.0
-
         # Only count tokens for actual LLM API call spans
         llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
         for entry in condensed_entries:
-
+            entry_function_name = entry.get("function", "") # Get function name safely
+            # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
+            is_llm_entry = entry.get("span_type") == "llm"
+            has_api_suffix = any(suffix in entry_function_name for suffix in llm_span_names)
+            output_is_dict = isinstance(entry.get("output"), dict)
+
+            # --- DEBUG PRINT 1: Check if condition passes ---
+            # if is_llm_entry and has_api_suffix and output_is_dict:
+            #     # print(f"[DEBUG TraceClient.save] Processing entry: {entry.get('span_id')} ({entry_function_name}) - Condition PASSED")
+            # elif is_llm_entry:
+            #     # Print why it failed if it was an LLM entry
+            #     print(f"[DEBUG TraceClient.save] Skipping LLM entry: {entry.get('span_id')} ({entry_function_name}) - Suffix Match: {has_api_suffix}, Output is Dict: {output_is_dict}")
+            # # --- END DEBUG ---
+
+            if is_llm_entry and has_api_suffix and output_is_dict:
                 output = entry["output"]
-                usage = output.get("usage", {})
-
+                usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
+
+                # --- DEBUG PRINT 2: Check extracted usage ---
+                # print(f"[DEBUG TraceClient.save] Extracted usage dict: {usage}")
+                # --- END DEBUG ---
+
+                # --- NEW: Extract model_name correctly from nested inputs ---
+                model_name = None
+                entry_inputs = entry.get("inputs", {})
+                # print(f"[DEBUG TraceClient.save] Inspecting inputs for span {entry.get('span_id')}: {entry_inputs}") # DEBUG Inputs
+                if entry_inputs:
+                    # Try common locations for model name within the inputs structure
+                    invocation_params = entry_inputs.get("invocation_params", {})
+                    serialized_data = entry_inputs.get("serialized", {})
+
+                    # Look in invocation_params (often directly contains model)
+                    if isinstance(invocation_params, dict):
+                        model_name = invocation_params.get("model")
+
+                    # Fallback: Check serialized 'repr' if it contains model info
+                    if not model_name and isinstance(serialized_data, dict):
+                        serialized_repr = serialized_data.get("repr", "")
+                        if "model_name=" in serialized_repr:
+                            try: # Simple parsing attempt
+                                model_name = serialized_repr.split("model_name='")[1].split("'")[0]
+                            except IndexError: pass # Ignore parsing errors
+
+                    # Fallback: Check top-level of invocation_params (sometimes passed flat)
+                    if not model_name and isinstance(invocation_params, dict):
+                        model_name = invocation_params.get("model") # Redundant check, but safe
+
+                    # Fallback: Check top-level of inputs itself (less likely for callbacks)
+                    if not model_name:
+                        model_name = entry_inputs.get("model")
+
+
+                # print(f"[DEBUG TraceClient.save] Determined model_name: {model_name}") # DEBUG Model Name
+                # --- END NEW ---
+
                 prompt_tokens = 0
-                completion_tokens = 0
-
-                # Handle OpenAI/Together format
+                completion_tokens = 0
+
+                # Handle OpenAI/Together format (checks within the 'usage' dict)
                 if "prompt_tokens" in usage:
                     prompt_tokens = usage.get("prompt_tokens", 0)
                     completion_tokens = usage.get("completion_tokens", 0)
-
-
-                # Handle Anthropic format
+
+                # Handle Anthropic format - MAP values to standard keys
                 elif "input_tokens" in usage:
-                    prompt_tokens = usage.get("input_tokens", 0)
-                    completion_tokens = usage.get("output_tokens", 0)
-
-
-
-
+                    prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
+                    completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
+
+                    # *** Overwrite the usage dict in the entry to use standard keys ***
+                    original_total = usage.get("total_tokens", 0)
+                    original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
+                    # Recalculate cost just in case it wasn't done correctly before
+                    temp_prompt_cost, temp_completion_cost = 0.0, 0.0
+                    if model_name:
+                        try:
+                            temp_prompt_cost, temp_completion_cost = cost_per_token(
+                                model=model_name,
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens
+                            )
+                        except Exception:
+                            pass # Ignore cost calculation errors here, focus on keys
+                    # Replace the usage dict with one using standard keys but Anthropic values
+                    output["usage"] = {
+                        "prompt_tokens": prompt_tokens,
+                        "completion_tokens": completion_tokens,
+                        "total_tokens": original_total,
+                        "prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
+                        "completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
+                        "total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
+                    }
+                    usage = output["usage"]
+
+                # Calculate costs if model name is available and ensure they are stored with standard keys
+                prompt_tokens = usage.get("prompt_tokens", 0)
+                completion_tokens = usage.get("completion_tokens", 0)
 
                 # Calculate costs if model name is available
                 if model_name:
                     try:
+                        # Recalculate costs based on potentially mapped tokens
                         prompt_cost, completion_cost = cost_per_token(
                             model=model_name,
                             prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens
                         )
-                        total_prompt_tokens_cost += prompt_cost
-                        total_completion_tokens_cost += completion_cost
-                        total_cost += prompt_cost + completion_cost
 
                         # Add cost information directly to the usage dictionary in the condensed entry
+                        # Ensure 'usage' exists in the output dict before modifying it
+                        # Add/Update cost information using standard keys
+
                         if "usage" not in output:
-                            output["usage"] = {}
+                            output["usage"] = {} # Initialize if missing
+                        elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
+                            print(f"[WARN TraceClient.save] Output 'usage' for span {entry.get('span_id')} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
+                            output["usage"] = {} # Reset to dict
+
                        output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
                         output["usage"]["completion_tokens_cost_usd"] = completion_cost
                         output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
                     except Exception as e:
                         # If cost calculation fails, continue without adding costs
-                        print(f"Error calculating cost for model '{model_name}': {str(e)}")
+                        print(f"Error calculating cost for model '{model_name}' (span: {entry.get('span_id')}): {str(e)}")
                         pass
+                else:
+                    print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {entry.get('span_id')}). Inputs: {entry_inputs}")
+
 
-        # Create trace document
+        # Create trace document - Always use standard keys for top-level counts
         trace_data = {
             "trace_id": self.trace_id,
             "name": self.name,
             "project_name": self.project_name,
             "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
             "duration": total_duration,
-            "token_counts": {
-                "prompt_tokens": total_prompt_tokens,
-                "completion_tokens": total_completion_tokens,
-                "total_tokens": total_tokens,
-                "prompt_tokens_cost_usd": total_prompt_tokens_cost,
-                "completion_tokens_cost_usd": total_completion_tokens_cost,
-                "total_cost_usd": total_cost
-            },
             "entries": condensed_entries,
             "evaluation_runs": evaluation_runs,
             "overwrite": overwrite,
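Most of this hunk normalizes provider-specific usage dictionaries onto OpenAI-style keys (`prompt_tokens`, `completion_tokens`, `total_tokens`, plus `*_cost_usd`) before computing per-span cost with `cost_per_token`, which appears to be LiteLLM's helper returning a `(prompt_cost, completion_cost)` pair. A reduced sketch of that mapping; treat the LiteLLM call and the example model name as assumptions:

```python
# Sketch of the usage-key normalization; cost_per_token is assumed to be
# litellm.cost_per_token, which returns (prompt_cost_usd, completion_cost_usd).
from typing import Optional
from litellm import cost_per_token

def normalize_usage(usage: dict, model_name: Optional[str]) -> dict:
    if "prompt_tokens" in usage:                     # OpenAI / Together style
        prompt, completion = usage["prompt_tokens"], usage.get("completion_tokens", 0)
    elif "input_tokens" in usage:                    # Anthropic style -> remap keys
        prompt, completion = usage["input_tokens"], usage.get("output_tokens", 0)
    else:
        prompt = completion = 0

    prompt_cost = completion_cost = 0.0
    if model_name:
        try:
            prompt_cost, completion_cost = cost_per_token(
                model=model_name, prompt_tokens=prompt, completion_tokens=completion
            )
        except Exception:
            pass                                     # unknown model: keep zero cost

    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": prompt + completion,
        "prompt_tokens_cost_usd": prompt_cost,
        "completion_tokens_cost_usd": completion_cost,
        "total_cost_usd": prompt_cost + completion_cost,
    }

print(normalize_usage({"input_tokens": 12, "output_tokens": 30}, "claude-3-haiku-20240307"))
```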
@@ -903,12 +1013,6 @@ class TraceClient:
             "parent_name": self.parent_name
         }
         # --- Log trace data before saving ---
-        try:
-            rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
-            rprint(json.dumps(trace_data, indent=2))
-        except Exception as log_e:
-            rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
-        # --- End logging ---
         self.trace_manager_client.save_trace(trace_data)
 
         return self.trace_id, trace_data
@@ -958,6 +1062,7 @@ class Tracer:
         self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
         self.organization_id: str = organization_id
         self._current_trace: Optional[str] = None
+        self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
         self.rules: List[Rule] = rules or [] # Store rules at tracer level
         self.initialized: bool = True
         self.enable_monitoring: bool = enable_monitoring
@@ -991,10 +1096,29 @@ class Tracer:
 
     def get_current_trace(self) -> Optional[TraceClient]:
         """
-        Get the current trace context
+        Get the current trace context.
+
+        Tries to get the trace client from the context variable first.
+        If not found (e.g., context lost across threads/tasks),
+        it falls back to the active trace client managed by the callback handler.
         """
-
+        trace_from_context = current_trace_var.get()
+        if trace_from_context:
+            return trace_from_context
 
+        # Fallback: Check the active client potentially set by a callback handler
+        if hasattr(self, '_active_trace_client') and self._active_trace_client:
+            # warnings.warn("Falling back to _active_trace_client in get_current_trace. ContextVar might be lost.", RuntimeWarning)
+            return self._active_trace_client
+
+        # If neither is available
+        # warnings.warn("No current trace found in context variable or active client fallback.", RuntimeWarning)
+        return None
+
+    def get_active_trace_client(self) -> Optional[TraceClient]:
+        """Returns the TraceClient instance currently marked as active by the handler."""
+        return self._active_trace_client
+
     def _apply_deep_tracing(self, func, span_type="span"):
         """
         Apply deep tracing to all functions in the same module as the given function.
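`get_current_trace` now prefers the context variable and only falls back to `_active_trace_client`, which the LangGraph callback handler sets when context variables do not survive thread or task boundaries. The lookup order, as a generic runnable sketch:

```python
import contextvars
from typing import Optional

current_trace_var: contextvars.ContextVar[Optional[object]] = contextvars.ContextVar(
    "current_trace", default=None
)

class TracerSketch:
    """Minimal stand-in showing the contextvar-first, attribute-fallback lookup."""
    def __init__(self) -> None:
        self._active_trace_client: Optional[object] = None  # set by a callback handler

    def get_current_trace(self) -> Optional[object]:
        trace = current_trace_var.get()        # 1) async/task-local context
        if trace is not None:
            return trace
        return self._active_trace_client       # 2) handler-managed fallback (may be None)

tracer = TracerSketch()
tracer._active_trace_client = "trace-from-handler"
print(tracer.get_current_trace())              # fallback -> "trace-from-handler"

token = current_trace_var.set("trace-from-context")
print(tracer.get_current_trace())              # context var wins -> "trace-from-context"
current_trace_var.reset(token)
```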
@@ -1314,43 +1438,29 @@ class Tracer:
 
         return wrapper
 
-    def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
-        """
-        Decorator to trace function execution with detailed entry/exit information.
-        """
-        if func is None:
-            return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
-
-        if asyncio.iscoroutinefunction(func):
-            @functools.wraps(func)
-            async def async_wrapper(*args, **kwargs):
-                # Get current trace from contextvars
-                current_trace = current_trace_var.get()
-                if current_trace and scorers:
-                    current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
-                return await func(*args, **kwargs)
-            return async_wrapper
-        else:
-            @functools.wraps(func)
-            def wrapper(*args, **kwargs):
-                # Get current trace from contextvars
-                current_trace = current_trace_var.get()
-                if current_trace and scorers:
-                    current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
-                return func(*args, **kwargs)
-            return wrapper
-
     def async_evaluate(self, *args, **kwargs):
         if not self.enable_evaluations:
             return
 
-        # Get
+        # --- Get trace_id passed explicitly (if any) ---
+        passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
+
+        # --- Get current trace from context FIRST ---
         current_trace = current_trace_var.get()
-
+
+        # --- Fallback Logic: Use active client only if context var is empty ---
+        if not current_trace:
+            current_trace = self._active_trace_client # Use the fallback
+        # --- End Fallback Logic ---
+
         if current_trace:
+            # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
+            # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
+            if passed_trace_id:
+                kwargs['trace_id'] = passed_trace_id # Re-add if needed by TraceClient.async_evaluate
             current_trace.async_evaluate(*args, **kwargs)
         else:
-            warnings.warn("No trace found, skipping evaluation")
+            warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
 
 
 def wrap(client: Any) -> Any:
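With the `score` decorator removed, `Tracer.async_evaluate` is the remaining entry point: it pops an optional `trace_id` kwarg, resolves the current trace from the context variable or the `_active_trace_client` fallback, and forwards everything to `TraceClient.async_evaluate`. A hedged sketch of typical usage; the `observe` decorator and constructor arguments follow the package's documented pattern and should be treated as assumptions:

```python
# Sketch: Tracer usage after this change; exact constructor arguments and the
# observe decorator are assumptions based on the package's documented usage.
from judgeval.common.tracer import Tracer
from judgeval.scorers import AnswerRelevancyScorer

judgment = Tracer(project_name="my_project")

@judgment.observe(span_type="function")
def answer(question: str) -> str:
    response = "Paris is the capital of France."  # stand-in for an LLM call
    # Attached to whichever trace is current: the context variable if it survived,
    # otherwise the handler-managed _active_trace_client fallback.
    judgment.async_evaluate(
        scorers=[AnswerRelevancyScorer(threshold=0.7)],
        input=question,
        actual_output=response,
        model="gpt-4o-mini",
    )
    return response
```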
@@ -1600,8 +1710,8 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         return {
             "content": response.content[0].text,
             "usage": {
-                "
-                "
+                "prompt_tokens": response.usage.input_tokens,
+                "completion_tokens": response.usage.output_tokens,
                 "total_tokens": response.usage.input_tokens + response.usage.output_tokens
             }
         }
@@ -1891,8 +2001,8 @@ async def _async_stream_wrapper(
     anthropic_final_usage = None
     if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
         anthropic_final_usage = {
-            "
-            "
+            "prompt_tokens": anthropic_input_tokens,
+            "completion_tokens": anthropic_output_tokens,
             "total_tokens": anthropic_input_tokens + anthropic_output_tokens
         }
 
@@ -2080,3 +2190,127 @@ class _TracedSyncStreamManagerWrapper(AbstractContextManager):
         if hasattr(self._original_manager, "__exit__"):
             return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
         return None
+
+# --- NEW Generalized Helper Function (Moved from demo) ---
+def prepare_evaluation_for_state(
+    scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+    example: Optional[Example] = None,
+    # --- Individual components (alternative to 'example') ---
+    input: Optional[str] = None,
+    actual_output: Optional[Union[str, List[str]]] = None,
+    expected_output: Optional[Union[str, List[str]]] = None,
+    context: Optional[List[str]] = None,
+    retrieval_context: Optional[List[str]] = None,
+    tools_called: Optional[List[str]] = None,
+    expected_tools: Optional[List[str]] = None,
+    additional_metadata: Optional[Dict[str, Any]] = None,
+    # --- Other eval parameters ---
+    model: Optional[str] = None,
+    log_results: Optional[bool] = True
+) -> Optional[EvaluationConfig]:
+    """
+    Prepares an EvaluationConfig object, similar to TraceClient.async_evaluate.
+
+    Accepts either a pre-made Example object or individual components to construct one.
+    Returns the EvaluationConfig object ready to be placed in the state, or None.
+    """
+    final_example = example
+
+    # If example is not provided, try to construct one from individual parts
+    if final_example is None:
+        # Basic validation: Ensure at least actual_output is present for most scorers
+        if actual_output is None:
+            # print("[prepare_evaluation_for_state] Warning: 'actual_output' is required when 'example' is not provided. Skipping evaluation setup.")
+            return None
+        try:
+            final_example = Example(
+                input=input,
+                actual_output=actual_output,
+                expected_output=expected_output,
+                context=context,
+                retrieval_context=retrieval_context,
+                tools_called=tools_called,
+                expected_tools=expected_tools,
+                additional_metadata=additional_metadata,
+                # trace_id will be set by the handler later if needed
+            )
+            # print("[prepare_evaluation_for_state] Constructed Example from individual components.")
+        except Exception as e:
+            # print(f"[prepare_evaluation_for_state] Error constructing Example: {e}. Skipping evaluation setup.")
+            return None
+
+    # If we have a valid example (provided or constructed) and scorers
+    if final_example and scorers:
+        # TODO: Add validation like check_examples if needed here,
+        # although the handler might implicitly handle some checks via TraceClient.
+        return EvaluationConfig(
+            scorers=scorers,
+            example=final_example,
+            model=model,
+            log_results=log_results
+        )
+    elif not scorers:
+        # print("[prepare_evaluation_for_state] No scorers provided. Skipping evaluation setup.")
+        return None
+    else: # No valid example
+        # print("[prepare_evaluation_for_state] No valid Example available. Skipping evaluation setup.")
+        return None
+# --- End NEW Helper Function ---
+
+# --- NEW: Helper function to simplify adding eval config to state ---
+def add_evaluation_to_state(
+    state: Dict[str, Any], # The LangGraph state dictionary
+    scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
+    # --- Evaluation components (same as prepare_evaluation_for_state) ---
+    input: Optional[str] = None,
+    actual_output: Optional[Union[str, List[str]]] = None,
+    expected_output: Optional[Union[str, List[str]]] = None,
+    context: Optional[List[str]] = None,
+    retrieval_context: Optional[List[str]] = None,
+    tools_called: Optional[List[str]] = None,
+    expected_tools: Optional[List[str]] = None,
+    additional_metadata: Optional[Dict[str, Any]] = None,
+    # --- Other eval parameters ---
+    model: Optional[str] = None,
+    log_results: Optional[bool] = True
+) -> None:
+    """
+    Prepares an EvaluationConfig and adds it to the state dictionary
+    under the '_judgeval_eval' key if successful.
+
+    This simplifies the process of setting up evaluations within LangGraph nodes.
+
+    Args:
+        state: The LangGraph state dictionary to modify.
+        scorers: List of scorer instances.
+        input: Input for the evaluation example.
+        actual_output: Actual output for the evaluation example.
+        expected_output: Expected output for the evaluation example.
+        context: Context for the evaluation example.
+        retrieval_context: Retrieval context for the evaluation example.
+        tools_called: Tools called for the evaluation example.
+        expected_tools: Expected tools for the evaluation example.
+        additional_metadata: Additional metadata for the evaluation example.
+        model: Model name used for generation (optional).
+        log_results: Whether to log evaluation results (optional, defaults to True).
+    """
+    eval_config = prepare_evaluation_for_state(
+        scorers=scorers,
+        input=input,
+        actual_output=actual_output,
+        expected_output=expected_output,
+        context=context,
+        retrieval_context=retrieval_context,
+        tools_called=tools_called,
+        expected_tools=expected_tools,
+        additional_metadata=additional_metadata,
+        model=model,
+        log_results=log_results
+    )
+
+    if eval_config:
+        state["_judgeval_eval"] = eval_config
+        # print(f"[_judgeval_eval added to state for node]") # Optional: Log confirmation
+
+    # print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
+# --- End NEW Helper ---
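These helpers let a LangGraph node request an evaluation declaratively: `add_evaluation_to_state` builds an `EvaluationConfig` and stores it under the `_judgeval_eval` key, where the judgeval callback handler (see `judgeval/integrations/langgraph.py` in this release) is expected to pick it up. A hedged sketch of a node using it; the state shape and scorer import are assumptions:

```python
# Sketch of a LangGraph node that attaches an evaluation to its state.
# The state schema and the scorer import are assumptions for illustration.
from typing import Any, Dict

from judgeval.common.tracer import add_evaluation_to_state
from judgeval.scorers import FaithfulnessScorer

def generate_answer_node(state: Dict[str, Any]) -> Dict[str, Any]:
    question = state["question"]
    docs = state.get("documents", [])
    answer = "Paris is the capital of France."   # stand-in for an LLM call

    # Stores an EvaluationConfig under state["_judgeval_eval"] when scorers and
    # actual_output are present; the LangGraph callback handler consumes it later.
    add_evaluation_to_state(
        state,
        scorers=[FaithfulnessScorer(threshold=0.8)],
        input=question,
        actual_output=answer,
        retrieval_context=docs,
        model="gpt-4o-mini",
    )

    state["answer"] = answer
    return state
```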