judgeval 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +3 -1
- judgeval/common/s3_storage.py +93 -0
- judgeval/common/tracer.py +869 -183
- judgeval/constants.py +1 -1
- judgeval/data/datasets/dataset.py +5 -1
- judgeval/data/datasets/eval_dataset_client.py +2 -2
- judgeval/data/sequence.py +16 -26
- judgeval/data/sequence_run.py +2 -0
- judgeval/judgment_client.py +44 -166
- judgeval/rules.py +4 -7
- judgeval/run_evaluation.py +2 -2
- judgeval/scorers/__init__.py +4 -4
- judgeval/scorers/judgeval_scorers/__init__.py +0 -176
- judgeval/version_check.py +22 -0
- {judgeval-0.0.31.dist-info → judgeval-0.0.33.dist-info}/METADATA +15 -2
- judgeval-0.0.33.dist-info/RECORD +63 -0
- judgeval/scorers/base_scorer.py +0 -58
- judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -27
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -276
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
- judgeval/scorers/judgeval_scorers/local_implementations/comparison/__init__.py +0 -0
- judgeval/scorers/judgeval_scorers/local_implementations/comparison/comparison_scorer.py +0 -161
- judgeval/scorers/judgeval_scorers/local_implementations/comparison/prompts.py +0 -222
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
- judgeval/scorers/judgeval_scorers/local_implementations/execution_order/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/execution_order/execution_order.py +0 -156
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -318
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -264
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
- judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/instruction_adherence.py +0 -232
- judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/prompt.py +0 -102
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -551
- judgeval-0.0.31.dist-info/RECORD +0 -96
- {judgeval-0.0.31.dist-info → judgeval-0.0.33.dist-info}/WHEEL +0 -0
- {judgeval-0.0.31.dist-info → judgeval-0.0.33.dist-info}/licenses/LICENSE.md +0 -0
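The most substantial changes are the new `judgeval/common/s3_storage.py` module (S3 trace export) and the reworked `judgeval/common/tracer.py` shown below. `s3_storage.py` itself is not included in this excerpt; the sketch below is a hypothetical reading of its surface, inferred only from the call sites visible in the tracer hunks (`S3Storage(bucket_name=..., aws_access_key_id=..., aws_secret_access_key=..., region_name=...)` and `s3_storage.save_trace(trace_data=..., trace_id=..., project_name=...)`). The boto3 usage and key layout are assumptions, not the packaged implementation.

```python
# Hypothetical sketch of the S3Storage surface implied by the tracer's call
# sites. The real judgeval/common/s3_storage.py (+93 lines) is not shown here.
import json
from datetime import datetime, timezone
from typing import Optional

import boto3  # assumed dependency; the actual module may differ


class S3Storage:
    def __init__(self, bucket_name: str,
                 aws_access_key_id: Optional[str] = None,
                 aws_secret_access_key: Optional[str] = None,
                 region_name: Optional[str] = None):
        self.bucket_name = bucket_name
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
        )

    def save_trace(self, trace_data: dict, trace_id: str, project_name: str) -> str:
        # Key layout is a guess: one JSON object per trace, grouped by project.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        s3_key = f"traces/{project_name}/{trace_id}_{timestamp}.json"
        self.s3_client.put_object(
            Bucket=self.bucket_name,
            Key=s3_key,
            Body=json.dumps(trace_data).encode("utf-8"),
        )
        return s3_key
```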
judgeval/common/tracer.py
CHANGED
@@ -11,12 +11,28 @@ import time
 import uuid
 import warnings
 import contextvars
-
+import sys
+from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
 from dataclasses import dataclass, field
 from datetime import datetime
 from http import HTTPStatus
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    AsyncGenerator,
+    TypeAlias,
+)
 from rich import print as rprint
+import types # <--- Add this import
 
 # Third-party imports
 import pika
@@ -27,6 +43,7 @@ from rich import print as rprint
 from openai import OpenAI, AsyncOpenAI
 from together import Together, AsyncTogether
 from anthropic import Anthropic, AsyncAnthropic
+from google import genai
 
 # Local application/library-specific imports
 from judgeval.constants import (
@@ -40,20 +57,22 @@ from judgeval.constants import (
 )
 from judgeval.judgment_client import JudgmentClient
 from judgeval.data import Example
-from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
+from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
 from judgeval.rules import Rule
 from judgeval.evaluation_run import EvaluationRun
 from judgeval.data.result import ScoringResult
 
 # Standard library imports needed for the new class
 import concurrent.futures
+from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
 
 # Define context variables for tracking the current trace and the current span within a trace
 current_trace_var = contextvars.ContextVar('current_trace', default=None)
-current_span_var = contextvars.ContextVar('current_span', default=None) #
+current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
+in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
 
 # Define type aliases for better code readability and maintainability
-ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether] # Supported API clients
+ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
 TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
 SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
 @dataclass
@@ -170,7 +189,7 @@ class TraceEntry:
             "inputs": self._serialize_inputs(),
             "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
             "span_type": self.span_type,
-            "parent_span_id": self.parent_span_id
+            "parent_span_id": self.parent_span_id,
         }
 
     def _serialize_output(self) -> Any:
@@ -185,6 +204,15 @@ class TraceEntry:
         if isinstance(self.output, BaseModel):
             return self.output.model_dump()
 
+        # NEW check: If output is the dict structure from our stream wrapper
+        if isinstance(self.output, dict) and 'streamed' in self.output:
+            # Assume it's already JSON-serializable (content is string, usage is dict or None)
+            return self.output
+        # NEW check: If output is the placeholder string before stream completes
+        elif self.output == "<pending stream>":
+            # Represent this state clearly in the serialized data
+            return {"status": "pending stream"}
+
         try:
             # Try to serialize the output to verify it's JSON compatible
             json.dumps(self.output)
@@ -203,9 +231,10 @@ class TraceManagerClient:
     - Saving a trace
     - Deleting a trace
     """
-    def __init__(self, judgment_api_key: str, organization_id: str):
+    def __init__(self, judgment_api_key: str, organization_id: str, tracer: Optional["Tracer"] = None):
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
+        self.tracer = tracer
 
     def fetch_trace(self, trace_id: str):
         """
@@ -233,12 +262,13 @@ class TraceManagerClient:
 
     def save_trace(self, trace_data: dict):
         """
-        Saves a trace to the
+        Saves a trace to the Judgment Supabase and optionally to S3 if configured.
 
         Args:
             trace_data: The trace data to save
        NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
         """
+        # Save to Judgment API
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
             json=trace_data,
@@ -255,6 +285,18 @@ class TraceManagerClient:
         elif response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save trace data: {response.text}")
 
+        # If S3 storage is enabled, save to S3 as well
+        if self.tracer and self.tracer.use_s3:
+            try:
+                s3_key = self.tracer.s3_storage.save_trace(
+                    trace_data=trace_data,
+                    trace_id=trace_data["trace_id"],
+                    project_name=trace_data["project_name"]
+                )
+                print(f"Trace also saved to S3 at key: {s3_key}")
+            except Exception as e:
+                warnings.warn(f"Failed to save trace to S3: {str(e)}")
+
         if "ui_results_url" in response.json():
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
@@ -352,7 +394,7 @@ class TraceClient:
         self.client: JudgmentClient = tracer.client
         self.entries: List[TraceEntry] = []
         self.start_time = time.time()
-        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
+        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
         self.visited_nodes = []
         self.executed_tools = []
         self.executed_node_tools = []
@@ -390,13 +432,13 @@ class TraceClient:
         entry = TraceEntry(
             type="enter",
             function=name,
-            span_id=span_id,
-            trace_id=self.trace_id,
+            span_id=span_id,
+            trace_id=self.trace_id,
             depth=current_depth,
             message=name,
             created_at=start_time,
             span_type=span_type,
-            parent_span_id=parent_span_id
+            parent_span_id=parent_span_id,
         )
         self.add_entry(entry)
 
@@ -414,7 +456,7 @@ class TraceClient:
             message=f"← {name}",
             created_at=time.time(),
             duration=duration,
-            span_type=span_type
+            span_type=span_type,
         ))
         # Clean up depth tracking for this span_id
         if span_id in self._span_depths:
@@ -451,47 +493,14 @@ class TraceClient:
             additional_metadata=additional_metadata,
             trace_id=self.trace_id
         )
-        loaded_rules = None
-        if self.rules:
-            loaded_rules = []
-            for rule in self.rules:
-                processed_conditions = []
-                for condition in rule.conditions:
-                    # Convert metric if it's a ScorerWrapper
-                    try:
-                        if isinstance(condition.metric, ScorerWrapper):
-                            condition_copy = condition.model_copy()
-                            condition_copy.metric = condition.metric.load_implementation(use_judgment=True)
-                            processed_conditions.append(condition_copy)
-                        else:
-                            processed_conditions.append(condition)
-                    except Exception as e:
-                        warnings.warn(f"Failed to convert ScorerWrapper in rule '{rule.name}', condition metric '{condition.metric_name}': {str(e)}")
-                        processed_conditions.append(condition) # Keep original condition as fallback
-
-                # Create new rule with processed conditions
-                new_rule = rule.model_copy()
-                new_rule.conditions = processed_conditions
-                loaded_rules.append(new_rule)
         try:
             # Load appropriate implementations for all scorers
-
-            for scorer in scorers:
-                try:
-                    if isinstance(scorer, ScorerWrapper):
-                        loaded_scorers.append(scorer.load_implementation(use_judgment=True))
-                    else:
-                        loaded_scorers.append(scorer)
-                except Exception as e:
-                    warnings.warn(f"Failed to load implementation for scorer {scorer}: {str(e)}")
-                    # Skip this scorer
-
-            if not loaded_scorers:
+            if not scorers:
                 warnings.warn("No valid scorers available for evaluation")
                 return
 
             # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
-            if
+            if self.rules and any(isinstance(scorer, JudgevalScorer) for scorer in scorers):
                 raise ValueError("Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types.")
 
         except Exception as e:
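With the `ScorerWrapper` conversion machinery removed, `async_evaluate` now expects ready-to-use scorer instances and merely validates them against the rules. A minimal sketch of calling the simplified path from inside a traced function — the scorer class and keyword arguments mirror judgeval's public examples, but treat the exact signature as an assumption against your installed version:

```python
# Hedged sketch: span-level evaluation through the simplified scorer path.
# Assumes JUDGMENT_API_KEY / JUDGMENT_ORG_ID are set in the environment;
# verify FaithfulnessScorer and the async_evaluate kwargs against your version.
from judgeval.common.tracer import Tracer
from judgeval.scorers import FaithfulnessScorer

judgment = Tracer(project_name="my_project")  # hypothetical project name

@judgment.observe(span_type="tool")
def answer(question: str) -> str:
    response = "Paris is the capital of France."
    # Scorers are passed through as-is now; no ScorerWrapper conversion step.
    judgment.get_current_trace().async_evaluate(
        scorers=[FaithfulnessScorer(threshold=0.8)],
        input=question,
        actual_output=response,
        retrieval_context=["France's capital is Paris."],
        model="gpt-4o",
    )
    return response
```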
@@ -505,15 +514,15 @@ class TraceClient:
                 project_name=self.project_name,
                 eval_name=f"{self.name.capitalize()}-"
                           f"{current_span_var.get()}-"
-                          f"[{','.join(scorer.score_type.capitalize() for scorer in
+                          f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
                 examples=[example],
-                scorers=
+                scorers=scorers,
                 model=model,
                 metadata={},
                 judgment_api_key=self.tracer.api_key,
                 override=self.overwrite,
                 trace_span_id=current_span_var.get(),
-                rules=
+                rules=self.rules # Use the combined rules
             )
 
             self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
@@ -571,7 +580,7 @@ class TraceClient:
             message=f"Inputs to {function_name}",
             created_at=time.time(),
             inputs=inputs,
-            span_type=entry_span_type
+            span_type=entry_span_type,
         ))
 
     async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
@@ -604,12 +613,15 @@ class TraceClient:
             message=f"Output from {function_name}",
             created_at=time.time(),
             output="<pending>" if inspect.iscoroutine(output) else output,
-            span_type=entry_span_type
+            span_type=entry_span_type,
         )
         self.add_entry(entry)
 
         if inspect.iscoroutine(output):
             asyncio.create_task(self._update_coroutine_output(entry, output))
+
+        # Return the created entry
+        return entry
 
     def add_entry(self, entry: TraceEntry):
         """Add a trace entry to this trace context"""
@@ -821,8 +833,10 @@ class TraceClient:
         total_completion_tokens_cost = 0.0
         total_cost = 0.0
 
+        # Only count tokens for actual LLM API call spans
+        llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
         for entry in condensed_entries:
-            if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
+            if entry.get("span_type") == "llm" and entry.get("function") in llm_span_names and isinstance(entry.get("output"), dict):
                 output = entry["output"]
                 usage = output.get("usage", {})
                 model_name = entry.get("inputs", {}).get("model", "")
@@ -888,6 +902,13 @@ class TraceClient:
             "parent_trace_id": self.parent_trace_id,
             "parent_name": self.parent_name
         }
+        # --- Log trace data before saving ---
+        try:
+            rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
+            rprint(json.dumps(trace_data, indent=2))
+        except Exception as log_e:
+            rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
+        # --- End logging ---
         self.trace_manager_client.save_trace(trace_data)
 
         return self.trace_id, trace_data
@@ -910,7 +931,14 @@ class Tracer:
         rules: Optional[List[Rule]] = None,  # Added rules parameter
         organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
         enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
-        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true"
+        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
+        # S3 configuration
+        use_s3: bool = False,
+        s3_bucket_name: Optional[str] = None,
+        s3_aws_access_key_id: Optional[str] = None,
+        s3_aws_secret_access_key: Optional[str] = None,
+        s3_region_name: Optional[str] = None,
+        deep_tracing: bool = True  # NEW: Enable deep tracing by default
     ):
         if not hasattr(self, 'initialized'):
             if not api_key:
@@ -918,6 +946,13 @@ class Tracer:
 
             if not organization_id:
                 raise ValueError("Tracer must be configured with an Organization ID")
+            if use_s3 and not s3_bucket_name:
+                raise ValueError("S3 bucket name must be provided when use_s3 is True")
+            if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
+                raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
+            if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
+                raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")
+
             self.api_key: str = api_key
             self.project_name: str = project_name
             self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
@@ -927,6 +962,19 @@ class Tracer:
             self.initialized: bool = True
             self.enable_monitoring: bool = enable_monitoring
             self.enable_evaluations: bool = enable_evaluations
+
+            # Initialize S3 storage if enabled
+            self.use_s3 = use_s3
+            if use_s3:
+                from judgeval.common.s3_storage import S3Storage
+                self.s3_storage = S3Storage(
+                    bucket_name=s3_bucket_name,
+                    aws_access_key_id=s3_aws_access_key_id,
+                    aws_secret_access_key=s3_aws_secret_access_key,
+                    region_name=s3_region_name
+                )
+            self.deep_tracing: bool = deep_tracing  # NEW: Store deep tracing setting
+
         elif hasattr(self, 'project_name') and self.project_name != project_name:
             warnings.warn(
                 f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
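The three Tracer hunks above wire the S3 option end to end: constructor parameters, fail-fast credential validation, and lazy `S3Storage` construction. A minimal usage sketch, assuming credentials come from the standard AWS environment variables (the bucket and project names are hypothetical):

```python
# Hedged sketch: enabling S3 export on the Tracer singleton. With use_s3=True,
# every TraceClient.save() also writes the trace JSON to the bucket via
# TraceManagerClient (see the save_trace hunk above).
import os
from judgeval.common.tracer import Tracer

judgment = Tracer(
    api_key=os.getenv("JUDGMENT_API_KEY"),
    project_name="my_project",          # hypothetical
    use_s3=True,
    s3_bucket_name="my-trace-archive",  # hypothetical bucket
    s3_region_name="us-east-1",
    # s3_aws_access_key_id / s3_aws_secret_access_key may be omitted when
    # AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are set (see validation above)
)
```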
@@ -941,12 +989,52 @@ class Tracer:
         """
         current_trace_var.set(trace)
 
-    def get_current_trace(self):
+    def get_current_trace(self) -> Optional[TraceClient]:
         """
         Get the current trace context from contextvars
         """
         return current_trace_var.get()
 
+    def _apply_deep_tracing(self, func, span_type="span"):
+        """
+        Apply deep tracing to all functions in the same module as the given function.
+
+        Args:
+            func: The function being traced
+            span_type: Type of span to use for traced functions
+
+        Returns:
+            A tuple of (module, original_functions_dict) where original_functions_dict
+            contains the original functions that were replaced with traced versions.
+        """
+        module = inspect.getmodule(func)
+        if not module:
+            return None, {}
+
+        # Save original functions
+        original_functions = {}
+
+        # Find all functions in the module
+        for name, obj in inspect.getmembers(module, inspect.isfunction):
+            # Skip already wrapped functions
+            if hasattr(obj, '_judgment_traced'):
+                continue
+
+            # Create a traced version of the function
+            # Always use default span type "span" for child functions
+            traced_func = _create_deep_tracing_wrapper(obj, self, "span")
+
+            # Mark the function as traced to avoid double wrapping
+            traced_func._judgment_traced = True
+
+            # Save the original function
+            original_functions[name] = obj
+
+            # Replace with traced version
+            setattr(module, name, traced_func)
+
+        return module, original_functions
+
     @contextmanager
     def trace(
         self,
@@ -992,14 +1080,8 @@ class Tracer:
         finally:
             # Reset the context variable
             current_trace_var.reset(token)
-
-    def
-        """
-        Get the current trace context from contextvars
-        """
-        return current_trace_var.get()
-
-    def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
+
+    def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
         """
         Decorator to trace function execution with detailed entry/exit information.
 
@@ -1009,20 +1091,37 @@ class Tracer:
             span_type: Type of span (default "span")
             project_name: Optional project name override
             overwrite: Whether to overwrite existing traces
+            deep_tracing: Whether to enable deep tracing for this function and all nested calls.
+                          If None, uses the tracer's default setting.
         """
         # If monitoring is disabled, return the function as is
         if not self.enable_monitoring:
             return func if func else lambda f: f
 
         if func is None:
-            return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
+            return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
+                                          overwrite=overwrite, deep_tracing=deep_tracing)
 
         # Use provided name or fall back to function name
         span_name = name or func.__name__
 
+        # Store custom attributes on the function object
+        func._judgment_span_name = span_name
+        func._judgment_span_type = span_type
+
+        # Use the provided deep_tracing value or fall back to the tracer's default
+        use_deep_tracing = deep_tracing if deep_tracing is not None else self.deep_tracing
+
         if asyncio.iscoroutinefunction(func):
             @functools.wraps(func)
             async def async_wrapper(*args, **kwargs):
+                # Check if we're already in a traced function
+                if in_traced_function_var.get():
+                    return await func(*args, **kwargs)
+
+                # Set in_traced_function_var to True
+                token = in_traced_function_var.set(True)
+
                 # Get current trace from context
                 current_trace = current_trace_var.get()
 
@@ -1057,9 +1156,18 @@ class Tracer:
                             'kwargs': kwargs
                         })
 
+                        # If deep tracing is enabled, apply monkey patching
+                        if use_deep_tracing:
+                            module, original_functions = self._apply_deep_tracing(func, span_type)
+
                         # Execute function
                         result = await func(*args, **kwargs)
 
+                        # Restore original functions if deep tracing was enabled
+                        if use_deep_tracing and module and 'original_functions' in locals():
+                            for name, obj in original_functions.items():
+                                setattr(module, name, obj)
+
                         # Record output
                         span.record_output(result)
 
@@ -1069,29 +1177,52 @@ class Tracer:
                 finally:
                     # Reset trace context (span context resets automatically)
                     current_trace_var.reset(trace_token)
+                    # Reset in_traced_function_var
+                    in_traced_function_var.reset(token)
             else:
                 # Already have a trace context, just create a span in it
                 # The span method handles current_span_var
-
-
-                    span
-
-
-
-
-
-
-
-
-
+
+                try:
+                    with current_trace.span(span_name, span_type=span_type) as span:  # MODIFIED: Use span_name directly
+                        # Record inputs
+                        span.record_input({
+                            'args': str(args),
+                            'kwargs': kwargs
+                        })
+
+                        # If deep tracing is enabled, apply monkey patching
+                        if use_deep_tracing:
+                            module, original_functions = self._apply_deep_tracing(func, span_type)
+
+                        # Execute function
+                        result = await func(*args, **kwargs)
+
+                        # Restore original functions if deep tracing was enabled
+                        if use_deep_tracing and module and 'original_functions' in locals():
+                            for name, obj in original_functions.items():
+                                setattr(module, name, obj)
+
+                        # Record output
+                        span.record_output(result)
 
                     return result
-
+                finally:
+                    # Reset in_traced_function_var
+                    in_traced_function_var.reset(token)
+
             return async_wrapper
         else:
-            # Non-async function implementation
+            # Non-async function implementation with deep tracing
             @functools.wraps(func)
             def wrapper(*args, **kwargs):
+                # Check if we're already in a traced function
+                if in_traced_function_var.get():
+                    return func(*args, **kwargs)
+
+                # Set in_traced_function_var to True
+                token = in_traced_function_var.set(True)
+
                 # Get current trace from context
                 current_trace = current_trace_var.get()
 
@@ -1126,9 +1257,18 @@ class Tracer:
                             'kwargs': kwargs
                         })
 
+                        # If deep tracing is enabled, apply monkey patching
+                        if use_deep_tracing:
+                            module, original_functions = self._apply_deep_tracing(func, span_type)
+
                         # Execute function
                         result = func(*args, **kwargs)
 
+                        # Restore original functions if deep tracing was enabled
+                        if use_deep_tracing and module and 'original_functions' in locals():
+                            for name, obj in original_functions.items():
+                                setattr(module, name, obj)
+
                         # Record output
                         span.record_output(result)
 
@@ -1138,24 +1278,40 @@ class Tracer:
                 finally:
                     # Reset trace context (span context resets automatically)
                     current_trace_var.reset(trace_token)
+                    # Reset in_traced_function_var
+                    in_traced_function_var.reset(token)
             else:
                 # Already have a trace context, just create a span in it
                 # The span method handles current_span_var
-
-
-                    span
-
-
-
-
-
-
-
-
-
+
+                try:
+                    with current_trace.span(span_name, span_type=span_type) as span:  # MODIFIED: Use span_name directly
+                        # Record inputs
+                        span.record_input({
+                            'args': str(args),
+                            'kwargs': kwargs
+                        })
+
+                        # If deep tracing is enabled, apply monkey patching
+                        if use_deep_tracing:
+                            module, original_functions = self._apply_deep_tracing(func, span_type)
+
+                        # Execute function
+                        result = func(*args, **kwargs)
+
+                        # Restore original functions if deep tracing was enabled
+                        if use_deep_tracing and module and 'original_functions' in locals():
+                            for name, obj in original_functions.items():
+                                setattr(module, name, obj)
+
+                        # Record output
+                        span.record_output(result)
 
                     return result
-
+                finally:
+                    # Reset in_traced_function_var
+                    in_traced_function_var.reset(token)
+
             return wrapper
 
     def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
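The net effect of the `observe` rewrite: decorate only the entry point, and `_apply_deep_tracing` temporarily monkey-patches the sibling functions in the same module so nested calls appear as child spans, while `in_traced_function_var` keeps nested decorated calls from re-entering the wrapper. A sketch of the intended usage (function and project names are illustrative):

```python
# Hedged sketch: deep tracing. Only main() is decorated; helper() is picked up
# by _apply_deep_tracing's module-level monkey patching while main() runs, so
# it shows up as a child "span" in the same trace without its own decorator.
from judgeval.common.tracer import Tracer

judgment = Tracer(project_name="my_project")  # hypothetical

def helper(x: int) -> int:   # NOT decorated
    return x * 2

@judgment.observe(deep_tracing=True)
def main() -> int:
    return helper(21)        # traced automatically as a child span

main()
```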
@@ -1200,96 +1356,192 @@ class Tracer:
 def wrap(client: Any) -> Any:
     """
     Wraps an API client to add tracing capabilities.
-    Supports OpenAI, Together, and
+    Supports OpenAI, Together, Anthropic, and Google GenAI clients.
+    Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
-
-    span_name, original_create = _get_client_config(client)
+    span_name, original_create, original_stream = _get_client_config(client)
 
-    #
-
-
-
-
-
-
-
-
+    # --- Define Traced Async Functions ---
+    async def traced_create_async(*args, **kwargs):
+        # [Existing logic - unchanged]
+        current_trace = current_trace_var.get()
+        if not current_trace:
+            if asyncio.iscoroutinefunction(original_create):
+                return await original_create(*args, **kwargs)
+            else:
+                return original_create(*args, **kwargs)
+
+        is_streaming = kwargs.get("stream", False)
+
+        with current_trace.span(span_name, span_type="llm") as span:
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Warn about token counting limitations with streaming
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
+                if not kwargs.get("stream_options", {}).get("include_usage"):
+                    warnings.warn(
+                        "OpenAI streaming calls don't include token counts by default. "
+                        "To enable token counting with streams, set stream_options={'include_usage': True} "
+                        "in your API call arguments.",
+                        UserWarning
+                    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                if is_streaming:
+                    stream_iterator = await original_create(*args, **kwargs)
+                    output_entry = span.record_output("<pending stream>")
+                    return _async_stream_wrapper(stream_iterator, client, output_entry)
+                else:
+                    awaited_response = await original_create(*args, **kwargs)
+                    output_data = _format_output_data(client, awaited_response)
+                    span.record_output(output_data)
+                    return awaited_response
+            except Exception as e:
+                print(f"Error during wrapped async API call ({span_name}): {e}")
+                span.record_output({"error": str(e)})
+                raise
+
+
+    # Function replacing .stream() - NOW returns the wrapper class instance
+    def traced_stream_async(*args, **kwargs):
+        current_trace = current_trace_var.get()
+        if not current_trace or not original_stream:
+            return original_stream(*args, **kwargs)
+        original_manager = original_stream(*args, **kwargs)
+        wrapper_manager = _TracedAsyncStreamManagerWrapper(
+            original_manager=original_manager,
+            client=client,
+            span_name=span_name,
+            trace_client=current_trace,
+            stream_wrapper_func=_async_stream_wrapper,
+            input_kwargs=kwargs
+        )
+        return wrapper_manager
+
+    # --- Define Traced Sync Functions ---
+    def traced_create_sync(*args, **kwargs):
+        # [Existing logic - unchanged]
+        current_trace = current_trace_var.get()
+        if not current_trace:
+            return original_create(*args, **kwargs)
+
+        is_streaming = kwargs.get("stream", False)
+
+        with current_trace.span(span_name, span_type="llm") as span:
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Warn about token counting limitations with streaming
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
+                if not kwargs.get("stream_options", {}).get("include_usage"):
+                    warnings.warn(
+                        "OpenAI streaming calls don't include token counts by default. "
+                        "To enable token counting with streams, set stream_options={'include_usage': True} "
+                        "in your API call arguments.",
+                        UserWarning
+                    )
+
+            try:
+                response_or_iterator = original_create(*args, **kwargs)
+            except Exception as e:
+                print(f"Error during wrapped sync API call ({span_name}): {e}")
+                span.record_output({"error": str(e)})
+                raise
+
+            if is_streaming:
+                output_entry = span.record_output("<pending stream>")
+                return _sync_stream_wrapper(response_or_iterator, client, output_entry)
+            else:
+                output_data = _format_output_data(client, response_or_iterator)
+                span.record_output(output_data)
+                return response_or_iterator
+
+
+    # Function replacing sync .stream()
+    def traced_stream_sync(*args, **kwargs):
+        current_trace = current_trace_var.get()
+        if not current_trace or not original_stream:
+            return original_stream(*args, **kwargs)
+        original_manager = original_stream(*args, **kwargs)
+        wrapper_manager = _TracedSyncStreamManagerWrapper(
+            original_manager=original_manager,
+            client=client,
+            span_name=span_name,
+            trace_client=current_trace,
+            stream_wrapper_func=_sync_stream_wrapper,
+            input_kwargs=kwargs
+        )
+        return wrapper_manager
+
+
+    # --- Assign Traced Methods to Client Instance ---
+    # [Assignment logic remains the same]
+    if isinstance(client, (AsyncOpenAI, AsyncTogether)):
+        client.chat.completions.create = traced_create_async
+        # Wrap the Responses API endpoint for OpenAI clients
+        if hasattr(client, "responses") and hasattr(client.responses, "create"):
+            # Capture the original responses.create
+            original_responses_create = client.responses.create
+            def traced_responses(*args, **kwargs):
+                # Get the current trace from contextvars
+                current_trace = current_trace_var.get()
+                # If no active trace, call the original
+                if not current_trace:
+                    return original_responses_create(*args, **kwargs)
+                # Trace this responses.create call
+                with current_trace.span(span_name, span_type="llm") as span:
+                    # Record raw input kwargs
+                    span.record_input(kwargs)
+                    # Make the actual API call
+                    response = original_responses_create(*args, **kwargs)
+                    # Record the output object
+                    span.record_output(response)
+                    return response
+            # Assign the traced wrapper
+            client.responses.create = traced_responses
+    elif isinstance(client, AsyncAnthropic):
+        client.messages.create = traced_create_async
+        if original_stream:
+            client.messages.stream = traced_stream_async
+    elif isinstance(client, genai.client.AsyncClient):
+        client.generate_content = traced_create_async
+    elif isinstance(client, (OpenAI, Together)):
+        client.chat.completions.create = traced_create_sync
+    elif isinstance(client, Anthropic):
+        client.messages.create = traced_create_sync
+        if original_stream:
+            client.messages.stream = traced_stream_sync
+    elif isinstance(client, genai.Client):
+        client.generate_content = traced_create_sync
 
-    with current_trace.span(span_name, span_type="llm") as span:
-        # Format and record the input parameters
-        input_data = _format_input_data(client, **kwargs)
-        span.record_input(input_data)
-
-        # Make the actual API call
-        try:
-            response = original_create(*args, **kwargs)
-        except Exception as e:
-            print(f"Error during API call: {e}")
-            raise
-
-        # Format and record the output
-        output_data = _format_output_data(client, response)
-        span.record_output(output_data)
-
-        return response
-
-
-    # Replace the original method with our traced version
-    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        client.chat.completions.create = traced_create
-    elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        client.messages.create = traced_create
-
     return client
 
 # Helper functions for client-specific operations
 
-def _get_client_config(client: ApiClient) -> tuple[str, callable]:
+def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable]]:
     """Returns configuration tuple for the given API client.
 
     Args:
         client: An instance of OpenAI, Together, or Anthropic client
 
     Returns:
-        tuple: (span_name, create_method)
+        tuple: (span_name, create_method, stream_method)
         - span_name: String identifier for tracing
         - create_method: Reference to the client's creation method
+        - stream_method: Reference to the client's stream method (if applicable)
 
     Raises:
         ValueError: If client type is not supported
     """
     if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return "OPENAI_API_CALL", client.chat.completions.create
+        return "OPENAI_API_CALL", client.chat.completions.create, None
     elif isinstance(client, (Together, AsyncTogether)):
-        return "TOGETHER_API_CALL", client.chat.completions.create
+        return "TOGETHER_API_CALL", client.chat.completions.create, None
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return "ANTHROPIC_API_CALL", client.messages.create
+        return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
+    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+        return "GOOGLE_API_CALL", client.models.generate_content, None
     raise ValueError(f"Unsupported client type: {type(client)}")
 
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
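From the caller's side, the rewritten `wrap()` keeps the SDK call shape unchanged; the only cooperation a streaming OpenAI call needs is `stream_options={'include_usage': True}`, which is exactly what the new warning asks for so the stream wrapper can record token counts. A hedged sketch (project name is hypothetical; assumes the usual API keys in the environment):

```python
# Hedged sketch: tracing a streamed OpenAI chat completion through wrap().
from openai import OpenAI
from judgeval.common.tracer import Tracer, wrap

judgment = Tracer(project_name="my_project")  # hypothetical
client = wrap(OpenAI())  # patches chat.completions.create (and responses.create)

@judgment.observe()
def ask(question: str) -> str:
    stream = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": question}],
        stream=True,
        stream_options={"include_usage": True},  # lets the wrapper record usage
    )
    # _sync_stream_wrapper passes chunks through unchanged while accumulating
    # content and the final usage chunk into the span's pending output entry.
    return "".join(chunk.choices[0].delta.content or ""
                   for chunk in stream if chunk.choices)
```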
@@ -1303,6 +1555,11 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
         "model": kwargs.get("model"),
         "messages": kwargs.get("messages"),
     }
+    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+        return {
+            "model": kwargs.get("model"),
+            "contents": kwargs.get("contents")
+        }
     # Anthropic requires additional max_tokens parameter
     return {
         "model": kwargs.get("model"),
|
|
1330
1587
|
"total_tokens": response.usage.total_tokens
|
1331
1588
|
}
|
1332
1589
|
}
|
1590
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1591
|
+
return {
|
1592
|
+
"content": response.candidates[0].content.parts[0].text,
|
1593
|
+
"usage": {
|
1594
|
+
"prompt_tokens": response.usage_metadata.prompt_token_count,
|
1595
|
+
"completion_tokens": response.usage_metadata.candidates_token_count,
|
1596
|
+
"total_tokens": response.usage_metadata.total_token_count
|
1597
|
+
}
|
1598
|
+
}
|
1333
1599
|
# Anthropic has a different response structure
|
1334
1600
|
return {
|
1335
1601
|
"content": response.content[0].text,
|
@@ -1340,29 +1606,117 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         }
     }
 
-#
-#
-
-#
-
-
-
-#
-
-#
-
-#
-
-#
-
-#
-
-
-
-
-#
-
-
+# Define a blocklist of functions that should not be traced
+# These are typically utility functions, print statements, logging, etc.
+_TRACE_BLOCKLIST = {
+    # Built-in functions
+    'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
+    'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
+    'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
+    # Logging functions
+    'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
+    # Common utility functions
+    'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
+    # String operations
+    'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
+    # Dict operations
+    'get', 'items', 'keys', 'values', 'update',
+    # List operations
+    'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
+}
+
+
+# Add a new function for deep tracing at the module level
+def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
+    """
+    Creates a wrapper for a function that automatically traces it when called within a traced function.
+    This enables deep tracing without requiring explicit @observe decorators on every function.
+
+    Args:
+        func: The function to wrap
+        tracer: The Tracer instance
+        span_type: Type of span (default "span")
+
+    Returns:
+        A wrapped function that will be traced when called
+    """
+    # Skip wrapping if the function is not callable or is a built-in
+    if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
+        return func
+
+    # Skip functions in the blocklist
+    if func.__name__ in _TRACE_BLOCKLIST:
+        return func
+
+    # Skip functions from certain modules (logging, sys, etc.)
+    if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
+        return func
+
+    # Get function name for the span - check for custom name set by @observe
+    func_name = getattr(func, '_judgment_span_name', func.__name__)
+
+    # Check for custom span_type set by @observe
+    func_span_type = getattr(func, '_judgment_span_type', "span")
+
+    # Store original function to prevent losing reference
+    original_func = func
+
+    # Create appropriate wrapper based on whether the function is async or not
+    if asyncio.iscoroutinefunction(func):
+        @functools.wraps(func)
+        async def async_deep_wrapper(*args, **kwargs):
+            # Get current trace from context
+            current_trace = current_trace_var.get()
+
+            # If no trace context, just call the function
+            if not current_trace:
+                return await original_func(*args, **kwargs)
+
+            # Create a span for this function call - use custom span_type if available
+            with current_trace.span(func_name, span_type=func_span_type) as span:
+                # Record inputs
+                span.record_input({
+                    'args': str(args),
+                    'kwargs': kwargs
+                })
+
+                # Execute function
+                result = await original_func(*args, **kwargs)
+
+                # Record output
+                span.record_output(result)
+
+            return result
+
+        return async_deep_wrapper
+    else:
+        @functools.wraps(func)
+        def deep_wrapper(*args, **kwargs):
+            # Get current trace from context
+            current_trace = current_trace_var.get()
+
+            # If no trace context, just call the function
+            if not current_trace:
+                return original_func(*args, **kwargs)
+
+            # Create a span for this function call - use custom span_type if available
+            with current_trace.span(func_name, span_type=func_span_type) as span:
+                # Record inputs
+                span.record_input({
+                    'args': str(args),
+                    'kwargs': kwargs
+                })
+
+                # Execute function
+                result = original_func(*args, **kwargs)
+
+                # Record output
+                span.record_output(result)
+
+            return result
+
+        return deep_wrapper
 
 # Add the new TraceThreadPoolExecutor class
 class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
@@ -1393,4 +1747,336 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
|
|
1393
1747
|
return super().submit(ctx.run, func_with_bound_args)
|
1394
1748
|
|
1395
1749
|
# Note: The `map` method would also need to be overridden for full context
|
1396
|
-
# propagation if users rely on it, but `submit` is the most common use case.
|
1750
|
+
# propagation if users rely on it, but `submit` is the most common use case.
|
1751
|
+
|
1752
|
+
# Helper functions for stream processing
|
1753
|
+
# ---------------------------------------
|
1754
|
+
|
1755
|
+
def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
|
1756
|
+
"""Extracts the text content from a stream chunk based on the client type."""
|
1757
|
+
try:
|
1758
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1759
|
+
return chunk.choices[0].delta.content
|
1760
|
+
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1761
|
+
# Anthropic streams various event types, we only care for content blocks
|
1762
|
+
if chunk.type == "content_block_delta":
|
1763
|
+
return chunk.delta.text
|
1764
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1765
|
+
# Google streams Candidate objects
|
1766
|
+
if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
|
1767
|
+
return chunk.candidates[0].content.parts[0].text
|
1768
|
+
except (AttributeError, IndexError, KeyError):
|
1769
|
+
# Handle cases where chunk structure is unexpected or doesn't contain content
|
1770
|
+
pass # Return None
|
1771
|
+
return None
|
1772
|
+
|
1773
|
+
def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
|
1774
|
+
"""Extracts usage data if present in the *final* chunk (client-specific)."""
|
1775
|
+
try:
|
1776
|
+
# OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
|
1777
|
+
# This typically requires specific API versions or settings. Often usage is *not* streamed.
|
1778
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1779
|
+
# Check if usage is directly on the chunk (some models might do this)
|
1780
|
+
if hasattr(chunk, 'usage') and chunk.usage:
|
1781
|
+
return {
|
1782
|
+
"prompt_tokens": chunk.usage.prompt_tokens,
|
1783
|
+
"completion_tokens": chunk.usage.completion_tokens,
|
1784
|
+
"total_tokens": chunk.usage.total_tokens
|
1785
|
+
}
|
1786
|
+
# Check if usage is nested within choices (less common for final chunk, but check)
|
1787
|
+
elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
|
1788
|
+
usage = chunk.choices[0].usage
|
1789
|
+
return {
|
1790
|
+
"prompt_tokens": usage.prompt_tokens,
|
1791
|
+
"completion_tokens": usage.completion_tokens,
|
1792
|
+
"total_tokens": usage.total_tokens
|
1793
|
+
}
|
1794
|
+
# Anthropic includes usage in the 'message_stop' event type
|
1795
|
+
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1796
|
+
if chunk.type == "message_stop":
|
1797
|
+
# Anthropic final usage is often attached to the *message* object, not the chunk directly
|
1798
|
+
# The API might provide a way to get the final message object, but typically not in the stream itself.
|
1799
|
+
# Let's assume for now usage might appear in the final *chunk* metadata if supported.
|
1800
|
+
# This is a placeholder - Anthropic usage typically needs a separate call or context.
|
1801
|
+
pass
|
1802
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1803
|
+
# Google provides usage metadata on the full response object, not typically streamed per chunk.
|
1804
|
+
# It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
|
1805
|
+
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
1806
|
+
return {
|
1807
|
+
"prompt_tokens": chunk.usage_metadata.prompt_token_count,
|
1808
|
+
"completion_tokens": chunk.usage_metadata.candidates_token_count,
|
1809
|
+
"total_tokens": chunk.usage_metadata.total_token_count
|
1810
|
+
}
|
1811
|
+
|
1812
|
+
except (AttributeError, IndexError, KeyError, TypeError):
|
1813
|
+
# Handle cases where usage data is missing or malformed
|
1814
|
+
pass # Return None
|
1815
|
+
return None
|
1816
|
+
|
1817
|
+
|
1818
|
+
# --- Sync Stream Wrapper ---
|
1819
|
+
def _sync_stream_wrapper(
|
1820
|
+
original_stream: Iterator,
|
1821
|
+
client: ApiClient,
|
1822
|
+
output_entry: TraceEntry
|
1823
|
+
) -> Generator[Any, None, None]:
|
1824
|
+
"""Wraps a synchronous stream iterator to capture content and update the trace."""
|
1825
|
+
content_parts = [] # Use a list instead of string concatenation
|
1826
|
+
final_usage = None
|
1827
|
+
last_chunk = None
|
1828
|
+
try:
|
1829
|
+
for chunk in original_stream:
|
1830
|
+
content_part = _extract_content_from_chunk(client, chunk)
|
1831
|
+
if content_part:
|
1832
|
+
content_parts.append(content_part) # Append to list instead of concatenating
|
1833
|
+
last_chunk = chunk # Keep track of the last chunk for potential usage data
|
1834
|
+
yield chunk # Pass the chunk to the caller
|
1835
|
+
finally:
|
1836
|
+
# Attempt to extract usage from the last chunk received
|
1837
|
+
if last_chunk:
|
1838
|
+
final_usage = _extract_usage_from_final_chunk(client, last_chunk)
|
1839
|
+
|
1840
|
+
# Update the trace entry with the accumulated content and usage
|
1841
|
+
output_entry.output = {
|
1842
|
+
"content": "".join(content_parts), # Join list at the end
|
1843
|
+
"usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
|
1844
|
+
"streamed": True
|
1845
|
+
}
|
1846
|
+
# Note: We might need to adjust _serialize_output if this dict causes issues,
|
1847
|
+
# but Pydantic's model_dump should handle dicts.
|
1848
|
+
|
1849
|
+
# --- Async Stream Wrapper ---
|
1850
|
+
async def _async_stream_wrapper(
|
1851
|
+
original_stream: AsyncIterator,
|
1852
|
+
client: ApiClient,
|
1853
|
+
output_entry: TraceEntry
|
1854
|
+
) -> AsyncGenerator[Any, None]:
|
1855
|
+
# [Existing logic - unchanged]
|
1856
|
+
content_parts = [] # Use a list instead of string concatenation
|
1857
|
+
final_usage_data = None
|
1858
|
+
last_content_chunk = None
|
1859
|
+
anthropic_input_tokens = 0
|
1860
|
+
anthropic_output_tokens = 0
|
1861
|
+
|
1862
|
+
target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
|
1863
|
+
|
1864
|
+
try:
|
1865
|
+
async for chunk in original_stream:
|
1866
|
+
# Check for OpenAI's final usage chunk
|
1867
|
+
if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
|
1868
|
+
final_usage_data = {
|
1869
|
+
"prompt_tokens": chunk.usage.prompt_tokens,
|
1870
|
+
"completion_tokens": chunk.usage.completion_tokens,
|
1871
|
+
"total_tokens": chunk.usage.total_tokens
|
1872
|
+
}
|
1873
|
+
yield chunk
|
1874
|
+
continue
|
1875
|
+
|
1876
|
+
if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
|
1877
|
+
if chunk.type == "message_start":
|
1878
|
+
if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
|
1879
|
+
anthropic_input_tokens = chunk.message.usage.input_tokens
|
1880
|
+
elif chunk.type == "message_delta":
|
1881
|
+
if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
|
1882
|
+
anthropic_output_tokens += chunk.usage.output_tokens
|
1883
|
+
|
1884
|
+
content_part = _extract_content_from_chunk(client, chunk)
|
1885
|
+
if content_part:
|
1886
|
+
content_parts.append(content_part) # Append to list instead of concatenating
|
1887
|
+
last_content_chunk = chunk
|
1888
|
+
|
1889
|
+
yield chunk
|
1890
|
+
finally:
|
1891
|
+
anthropic_final_usage = None
|
1892
|
+
if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
|
1893
|
+
anthropic_final_usage = {
|
1894
|
+
"input_tokens": anthropic_input_tokens,
|
1895
|
+
"output_tokens": anthropic_output_tokens,
|
1896
|
+
"total_tokens": anthropic_input_tokens + anthropic_output_tokens
|
1897
|
+
}
|
1898
|
+
|
1899
|
+
usage_info = None
|
1900
|
+
if final_usage_data:
|
1901
|
+
usage_info = final_usage_data
|
1902
|
+
elif anthropic_final_usage:
|
1903
|
+
usage_info = anthropic_final_usage
|
1904
|
+
elif last_content_chunk:
|
1905
|
+
usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
|
1906
|
+
|
1907
|
+
if output_entry and hasattr(output_entry, 'output'):
|
1908
|
+
output_entry.output = {
|
1909
|
+
"content": "".join(content_parts), # Join list at the end
|
1910
|
+
"usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
|
1911
|
+
"streamed": True
|
1912
|
+
}
|
1913
|
+
start_ts = getattr(output_entry, 'created_at', time.time())
|
1914
|
+
output_entry.duration = time.time() - start_ts
|
1915
|
+
# else: # Handle error case if necessary, but remove debug print
|
1916
|
+
|
1917
|
+
+# --- Define Context Manager Wrapper Classes ---
+class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
+    """Wraps an original async stream manager to add tracing."""
+    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+        self._original_manager = original_manager
+        self._client = client
+        self._span_name = span_name
+        self._trace_client = trace_client
+        self._stream_wrapper_func = stream_wrapper_func
+        self._input_kwargs = input_kwargs
+        self._parent_span_id_at_entry = None
+
+    async def __aenter__(self):
+        self._parent_span_id_at_entry = current_span_var.get()
+        if not self._trace_client:
+            # If no trace, just delegate to the original manager
+            return await self._original_manager.__aenter__()
+
+        # --- Manually create the 'enter' entry ---
+        start_time = time.time()
+        span_id = str(uuid.uuid4())
+        current_depth = 0
+        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        self._trace_client._span_depths[span_id] = current_depth
+        enter_entry = TraceEntry(
+            type="enter", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+            created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+        )
+        self._trace_client.add_entry(enter_entry)
+        # --- End manual 'enter' entry ---
+
+        # Set the current span ID in contextvars
+        self._span_context_token = current_span_var.set(span_id)
+
+        # Manually create 'input' entry
+        input_data = _format_input_data(self._client, **self._input_kwargs)
+        input_entry = TraceEntry(
+            type="input", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+            created_at=time.time(), inputs=input_data, span_type="llm"
+        )
+        self._trace_client.add_entry(input_entry)
+
+        # Call the original __aenter__
+        raw_iterator = await self._original_manager.__aenter__()
+
+        # Manually create pending 'output' entry
+        output_entry = TraceEntry(
+            type="output", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+            created_at=time.time(), output="<pending stream>", span_type="llm"
+        )
+        self._trace_client.add_entry(output_entry)
+
+        # Wrap the raw iterator so the stream wrapper can fill in the pending output
+        wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+        return wrapped_iterator
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        # Manually create the 'exit' entry
+        if hasattr(self, '_span_context_token'):
+            span_id = current_span_var.get()
+            start_time_for_duration = 0
+            for entry in reversed(self._trace_client.entries):
+                if entry.span_id == span_id and entry.type == 'enter':
+                    start_time_for_duration = entry.created_at
+                    break
+            duration = time.time() - start_time_for_duration if start_time_for_duration else None
+            exit_depth = self._trace_client._span_depths.get(span_id, 0)
+            exit_entry = TraceEntry(
+                type="exit", function=self._span_name, span_id=span_id,
+                trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                created_at=time.time(), duration=duration, span_type="llm"
+            )
+            self._trace_client.add_entry(exit_entry)
+            if span_id in self._trace_client._span_depths:
+                del self._trace_client._span_depths[span_id]
+            current_span_var.reset(self._span_context_token)
+            delattr(self, '_span_context_token')
+
+        # Delegate __aexit__
+        if hasattr(self._original_manager, "__aexit__"):
+            return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
+        return None
+
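Before the sync twin below, it may help to see the delegation skeleton in isolation: the wrapper is itself an async context manager that forwards __aenter__/__aexit__ to the manager it wraps, inserting trace entries around those calls. A stripped-down sketch with tracing elided; all names here are hypothetical stand-ins:

    import asyncio
    from contextlib import AbstractAsyncContextManager

    class _Inner:
        async def __aenter__(self):
            return ["chunk-1", "chunk-2"]     # stand-in for the raw stream iterator
        async def __aexit__(self, exc_type, exc_val, exc_tb):
            return None

    class _Passthrough(AbstractAsyncContextManager):
        """Delegation shape of _TracedAsyncStreamManagerWrapper, minus tracing."""
        def __init__(self, inner):
            self._inner = inner
        async def __aenter__(self):
            # tracing would record 'enter'/'input' entries here
            return await self._inner.__aenter__()
        async def __aexit__(self, exc_type, exc_val, exc_tb):
            # tracing would record the 'exit' entry here
            return await self._inner.__aexit__(exc_type, exc_val, exc_tb)

    async def main():
        async with _Passthrough(_Inner()) as stream:
            print(stream)

    asyncio.run(main())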
+class _TracedSyncStreamManagerWrapper(AbstractContextManager):
+    """Wraps an original sync stream manager to add tracing."""
+    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+        self._original_manager = original_manager
+        self._client = client
+        self._span_name = span_name
+        self._trace_client = trace_client
+        self._stream_wrapper_func = stream_wrapper_func
+        self._input_kwargs = input_kwargs
+        self._parent_span_id_at_entry = None
+
+    def __enter__(self):
+        self._parent_span_id_at_entry = current_span_var.get()
+        if not self._trace_client:
+            return self._original_manager.__enter__()
+
+        # Manually create 'enter' entry
+        start_time = time.time()
+        span_id = str(uuid.uuid4())
+        current_depth = 0
+        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        self._trace_client._span_depths[span_id] = current_depth
+        enter_entry = TraceEntry(
+            type="enter", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+            created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+        )
+        self._trace_client.add_entry(enter_entry)
+        self._span_context_token = current_span_var.set(span_id)
+
+        # Manually create 'input' entry
+        input_data = _format_input_data(self._client, **self._input_kwargs)
+        input_entry = TraceEntry(
+            type="input", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+            created_at=time.time(), inputs=input_data, span_type="llm"
+        )
+        self._trace_client.add_entry(input_entry)
+
+        # Call original __enter__
+        raw_iterator = self._original_manager.__enter__()
+
+        # Manually create 'output' entry (pending)
+        output_entry = TraceEntry(
+            type="output", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+            created_at=time.time(), output="<pending stream>", span_type="llm"
+        )
+        self._trace_client.add_entry(output_entry)
+
+        # Wrap the raw iterator
+        wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+        return wrapped_iterator
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Manually create 'exit' entry
+        if hasattr(self, '_span_context_token'):
+            span_id = current_span_var.get()
+            start_time_for_duration = 0
+            for entry in reversed(self._trace_client.entries):
+                if entry.span_id == span_id and entry.type == 'enter':
+                    start_time_for_duration = entry.created_at
+                    break
+            duration = time.time() - start_time_for_duration if start_time_for_duration else None
+            exit_depth = self._trace_client._span_depths.get(span_id, 0)
+            exit_entry = TraceEntry(
+                type="exit", function=self._span_name, span_id=span_id,
+                trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                created_at=time.time(), duration=duration, span_type="llm"
+            )
+            self._trace_client.add_entry(exit_entry)
+            if span_id in self._trace_client._span_depths:
+                del self._trace_client._span_depths[span_id]
+            current_span_var.reset(self._span_context_token)
+            delattr(self, '_span_context_token')
+
+        # Delegate __exit__
+        if hasattr(self._original_manager, "__exit__"):
+            return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
+        return None
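One detail worth flagging in both exit paths: the start time is recovered by scanning entries backwards for the matching 'enter' record, and a falsy start time (0, i.e. no match found) yields duration=None rather than a nonsense wall-clock value. A self-contained sketch of that lookup; the Entry dataclass is a stand-in for TraceEntry:

    import time
    from dataclasses import dataclass

    @dataclass
    class Entry:                      # stand-in for TraceEntry
        span_id: str
        type: str
        created_at: float

    entries = [
        Entry("s1", "enter", time.time() - 1.5),
        Entry("s1", "output", time.time()),
    ]

    start = 0
    for e in reversed(entries):       # most recent matching 'enter' wins
        if e.span_id == "s1" and e.type == "enter":
            start = e.created_at
            break
    duration = time.time() - start if start else None
    print(duration)                   # ~1.5 seconds; None if no 'enter' was found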