judgeval 0.0.32__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/s3_storage.py +93 -0
- judgeval/common/tracer.py +612 -123
- judgeval/data/sequence.py +4 -10
- judgeval/judgment_client.py +25 -86
- judgeval/rules.py +4 -7
- judgeval/run_evaluation.py +1 -1
- judgeval/scorers/__init__.py +4 -4
- judgeval/scorers/judgeval_scorers/__init__.py +0 -176
- {judgeval-0.0.32.dist-info → judgeval-0.0.33.dist-info}/METADATA +15 -2
- judgeval-0.0.33.dist-info/RECORD +63 -0
- judgeval/scorers/base_scorer.py +0 -58
- judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -27
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -276
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
- judgeval/scorers/judgeval_scorers/local_implementations/comparison/__init__.py +0 -0
- judgeval/scorers/judgeval_scorers/local_implementations/comparison/comparison_scorer.py +0 -161
- judgeval/scorers/judgeval_scorers/local_implementations/comparison/prompts.py +0 -222
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
- judgeval/scorers/judgeval_scorers/local_implementations/execution_order/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/execution_order/execution_order.py +0 -156
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -318
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -264
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
- judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/instruction_adherence.py +0 -232
- judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/prompt.py +0 -102
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -551
- judgeval-0.0.32.dist-info/RECORD +0 -97
- {judgeval-0.0.32.dist-info → judgeval-0.0.33.dist-info}/WHEEL +0 -0
- {judgeval-0.0.32.dist-info → judgeval-0.0.33.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
```diff
@@ -12,12 +12,27 @@ import uuid
 import warnings
 import contextvars
 import sys
-from contextlib import contextmanager
+from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
 from dataclasses import dataclass, field
 from datetime import datetime
 from http import HTTPStatus
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    AsyncGenerator,
+    TypeAlias,
+)
 from rich import print as rprint
+import types # <--- Add this import
 
 # Third-party imports
 import pika
```
```diff
@@ -42,13 +57,14 @@ from judgeval.constants import (
 )
 from judgeval.judgment_client import JudgmentClient
 from judgeval.data import Example
-from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
+from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
 from judgeval.rules import Rule
 from judgeval.evaluation_run import EvaluationRun
 from judgeval.data.result import ScoringResult
 
 # Standard library imports needed for the new class
 import concurrent.futures
+from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
 
 # Define context variables for tracking the current trace and the current span within a trace
 current_trace_var = contextvars.ContextVar('current_trace', default=None)
```
```diff
@@ -173,7 +189,7 @@ class TraceEntry:
             "inputs": self._serialize_inputs(),
             "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
             "span_type": self.span_type,
-            "parent_span_id": self.parent_span_id
+            "parent_span_id": self.parent_span_id,
         }
 
     def _serialize_output(self) -> Any:
```
```diff
@@ -188,6 +204,15 @@ class TraceEntry:
         if isinstance(self.output, BaseModel):
             return self.output.model_dump()
 
+        # NEW check: If output is the dict structure from our stream wrapper
+        if isinstance(self.output, dict) and 'streamed' in self.output:
+            # Assume it's already JSON-serializable (content is string, usage is dict or None)
+            return self.output
+        # NEW check: If output is the placeholder string before stream completes
+        elif self.output == "<pending stream>":
+            # Represent this state clearly in the serialized data
+            return {"status": "pending stream"}
+
         try:
             # Try to serialize the output to verify it's JSON compatible
             json.dumps(self.output)
```
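For reference, the `streamed` dict that `_serialize_output` now passes through is the shape written by the stream wrappers added at the end of this file. A minimal sketch of the two shapes (field names come from this diff; the values are illustrative):

```python
import json

# Shape written by _sync_stream_wrapper/_async_stream_wrapper once the stream is consumed
finished = {
    "content": "Hello world",
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
    "streamed": True,
}

# Both branches return something JSON-serializable, which is all json.dumps() needs
print(json.dumps(finished))
print(json.dumps({"status": "pending stream"}))  # placeholder while the stream is still open
```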
```diff
@@ -206,9 +231,10 @@ class TraceManagerClient:
         - Saving a trace
         - Deleting a trace
     """
-    def __init__(self, judgment_api_key: str, organization_id: str):
+    def __init__(self, judgment_api_key: str, organization_id: str, tracer: Optional["Tracer"] = None):
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
+        self.tracer = tracer
 
     def fetch_trace(self, trace_id: str):
         """
```
```diff
@@ -236,12 +262,13 @@ class TraceManagerClient:
 
     def save_trace(self, trace_data: dict):
         """
-        Saves a trace to the
+        Saves a trace to the Judgment Supabase and optionally to S3 if configured.
 
         Args:
             trace_data: The trace data to save
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
         """
+        # Save to Judgment API
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
             json=trace_data,
```
```diff
@@ -258,6 +285,18 @@ class TraceManagerClient:
         elif response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save trace data: {response.text}")
 
+        # If S3 storage is enabled, save to S3 as well
+        if self.tracer and self.tracer.use_s3:
+            try:
+                s3_key = self.tracer.s3_storage.save_trace(
+                    trace_data=trace_data,
+                    trace_id=trace_data["trace_id"],
+                    project_name=trace_data["project_name"]
+                )
+                print(f"Trace also saved to S3 at key: {s3_key}")
+            except Exception as e:
+                warnings.warn(f"Failed to save trace to S3: {str(e)}")
+
         if "ui_results_url" in response.json():
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
```
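The `S3Storage` class comes from the new `judgeval/common/s3_storage.py` (+93 lines, not expanded in this diff), so only the call signature above is confirmed. A hedged sketch of what a compatible implementation could look like, assuming boto3; the key layout and class internals are assumptions, not the shipped code:

```python
# Hypothetical sketch of an S3Storage compatible with the save_trace() call above.
import json
from datetime import datetime, timezone

import boto3  # assumed dependency for any S3-backed storage


class S3StorageSketch:
    def __init__(self, bucket_name, aws_access_key_id=None,
                 aws_secret_access_key=None, region_name=None):
        self.bucket_name = bucket_name
        self.client = boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
        )

    def save_trace(self, trace_data: dict, trace_id: str, project_name: str) -> str:
        # One plausible key layout: <project>/<trace_id>/<timestamp>.json
        key = f"{project_name}/{trace_id}/{datetime.now(timezone.utc).isoformat()}.json"
        self.client.put_object(
            Bucket=self.bucket_name,
            Key=key,
            Body=json.dumps(trace_data).encode("utf-8"),
        )
        return key
```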
```diff
@@ -355,7 +394,7 @@ class TraceClient:
         self.client: JudgmentClient = tracer.client
         self.entries: List[TraceEntry] = []
         self.start_time = time.time()
-        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
+        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
         self.visited_nodes = []
         self.executed_tools = []
         self.executed_node_tools = []
```
```diff
@@ -393,13 +432,13 @@ class TraceClient:
             entry = TraceEntry(
                 type="enter",
                 function=name,
-                span_id=span_id,
-                trace_id=self.trace_id,
+                span_id=span_id,
+                trace_id=self.trace_id,
                 depth=current_depth,
                 message=name,
                 created_at=start_time,
                 span_type=span_type,
-                parent_span_id=parent_span_id
+                parent_span_id=parent_span_id,
             )
             self.add_entry(entry)
 
```
```diff
@@ -417,7 +456,7 @@ class TraceClient:
                 message=f"← {name}",
                 created_at=time.time(),
                 duration=duration,
-                span_type=span_type
+                span_type=span_type,
             ))
             # Clean up depth tracking for this span_id
             if span_id in self._span_depths:
```
```diff
@@ -454,47 +493,14 @@ class TraceClient:
             additional_metadata=additional_metadata,
             trace_id=self.trace_id
         )
-        loaded_rules = None
-        if self.rules:
-            loaded_rules = []
-            for rule in self.rules:
-                processed_conditions = []
-                for condition in rule.conditions:
-                    # Convert metric if it's a ScorerWrapper
-                    try:
-                        if isinstance(condition.metric, ScorerWrapper):
-                            condition_copy = condition.model_copy()
-                            condition_copy.metric = condition.metric.load_implementation(use_judgment=True)
-                            processed_conditions.append(condition_copy)
-                        else:
-                            processed_conditions.append(condition)
-                    except Exception as e:
-                        warnings.warn(f"Failed to convert ScorerWrapper in rule '{rule.name}', condition metric '{condition.metric_name}': {str(e)}")
-                        processed_conditions.append(condition) # Keep original condition as fallback
-
-                # Create new rule with processed conditions
-                new_rule = rule.model_copy()
-                new_rule.conditions = processed_conditions
-                loaded_rules.append(new_rule)
         try:
             # Load appropriate implementations for all scorers
-
-            for scorer in scorers:
-                try:
-                    if isinstance(scorer, ScorerWrapper):
-                        loaded_scorers.append(scorer.load_implementation(use_judgment=True))
-                    else:
-                        loaded_scorers.append(scorer)
-                except Exception as e:
-                    warnings.warn(f"Failed to load implementation for scorer {scorer}: {str(e)}")
-                    # Skip this scorer
-
-            if not loaded_scorers:
+            if not scorers:
                 warnings.warn("No valid scorers available for evaluation")
                 return
 
             # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
-            if
+            if self.rules and any(isinstance(scorer, JudgevalScorer) for scorer in scorers):
                 raise ValueError("Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types.")
 
         except Exception as e:
```
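The net effect of the removal is that scorers and rules are now passed through as-is, with a single guard replacing the `ScorerWrapper` conversion machinery. A condensed, standalone restatement of the surviving logic (the function itself is illustrative, not part of the package):

```python
import warnings

from judgeval.scorers import JudgevalScorer  # import confirmed at the top of this file


def validate_scorers(scorers: list, rules) -> bool:
    """Illustrative restatement of the guard above."""
    if not scorers:
        warnings.warn("No valid scorers available for evaluation")
        return False
    # Local JudgevalScorer implementations cannot be combined with rules;
    # only APIJudgmentScorer types are allowed in that case.
    if rules and any(isinstance(s, JudgevalScorer) for s in scorers):
        raise ValueError("Cannot use Judgeval scorers when rules are set")
    return True
```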
```diff
@@ -508,15 +514,15 @@ class TraceClient:
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
                 f"{current_span_var.get()}-"
-                f"[{','.join(scorer.score_type.capitalize() for scorer in
+                f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
             examples=[example],
-            scorers=
+            scorers=scorers,
             model=model,
             metadata={},
             judgment_api_key=self.tracer.api_key,
             override=self.overwrite,
             trace_span_id=current_span_var.get(),
-            rules=
+            rules=self.rules # Use the combined rules
         )
 
         self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
```
```diff
@@ -574,7 +580,7 @@ class TraceClient:
             message=f"Inputs to {function_name}",
             created_at=time.time(),
             inputs=inputs,
-            span_type=entry_span_type
+            span_type=entry_span_type,
         ))
 
     async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
```
```diff
@@ -607,12 +613,15 @@ class TraceClient:
             message=f"Output from {function_name}",
             created_at=time.time(),
             output="<pending>" if inspect.iscoroutine(output) else output,
-            span_type=entry_span_type
+            span_type=entry_span_type,
         )
         self.add_entry(entry)
 
         if inspect.iscoroutine(output):
             asyncio.create_task(self._update_coroutine_output(entry, output))
+
+        # Return the created entry
+        return entry
 
     def add_entry(self, entry: TraceEntry):
         """Add a trace entry to this trace context"""
```
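Returning the entry from `record_output` is what enables the streaming paths below: the caller records a `"<pending stream>"` placeholder, keeps the returned `TraceEntry`, and mutates its `output` once the stream is exhausted. A self-contained sketch of that pattern (the `EntrySketch` class is a stand-in, not the real `TraceEntry`):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class EntrySketch:  # stand-in for TraceEntry, for illustration only
    output: Any = None


def record_output_sketch(value: Any) -> EntrySketch:
    entry = EntrySketch(output=value)
    return entry  # returning the entry lets the caller patch it later


entry = record_output_sketch("<pending stream>")
# ... the stream wrapper consumes the response elsewhere ...
entry.output = {"content": "final text", "usage": None, "streamed": True}
print(entry.output["content"])  # "final text"
```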
```diff
@@ -824,8 +833,10 @@ class TraceClient:
         total_completion_tokens_cost = 0.0
         total_cost = 0.0
 
+        # Only count tokens for actual LLM API call spans
+        llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
         for entry in condensed_entries:
-            if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
+            if entry.get("span_type") == "llm" and entry.get("function") in llm_span_names and isinstance(entry.get("output"), dict):
                 output = entry["output"]
                 usage = output.get("usage", {})
                 model_name = entry.get("inputs", {}).get("model", "")
```
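A quick illustration of why the extra `function` check matters (the entry dicts are hypothetical): a non-API span that happens to carry `span_type == "llm"` and a usage dict no longer inflates the totals.

```python
llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}

entries = [
    {"span_type": "llm", "function": "OPENAI_API_CALL",
     "output": {"usage": {"total_tokens": 42}}},
    # Hypothetical wrapper span echoing the same usage dict: previously counted twice
    {"span_type": "llm", "function": "my_llm_helper",
     "output": {"usage": {"total_tokens": 42}}},
]

total = sum(
    e["output"]["usage"]["total_tokens"]
    for e in entries
    if e.get("span_type") == "llm"
    and e.get("function") in llm_span_names
    and isinstance(e.get("output"), dict)
)
print(total)  # 42, not 84
```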
```diff
@@ -921,6 +932,12 @@ class Tracer:
         organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
         enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
         enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
+        # S3 configuration
+        use_s3: bool = False,
+        s3_bucket_name: Optional[str] = None,
+        s3_aws_access_key_id: Optional[str] = None,
+        s3_aws_secret_access_key: Optional[str] = None,
+        s3_region_name: Optional[str] = None,
         deep_tracing: bool = True # NEW: Enable deep tracing by default
     ):
         if not hasattr(self, 'initialized'):
```
```diff
@@ -929,6 +946,13 @@ class Tracer:
 
             if not organization_id:
                 raise ValueError("Tracer must be configured with an Organization ID")
+            if use_s3 and not s3_bucket_name:
+                raise ValueError("S3 bucket name must be provided when use_s3 is True")
+            if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
+                raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
+            if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
+                raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")
+
             self.api_key: str = api_key
             self.project_name: str = project_name
             self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
```
```diff
@@ -938,7 +962,19 @@ class Tracer:
             self.initialized: bool = True
             self.enable_monitoring: bool = enable_monitoring
             self.enable_evaluations: bool = enable_evaluations
+
+            # Initialize S3 storage if enabled
+            self.use_s3 = use_s3
+            if use_s3:
+                from judgeval.common.s3_storage import S3Storage
+                self.s3_storage = S3Storage(
+                    bucket_name=s3_bucket_name,
+                    aws_access_key_id=s3_aws_access_key_id,
+                    aws_secret_access_key=s3_aws_secret_access_key,
+                    region_name=s3_region_name
+                )
             self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
+
         elif hasattr(self, 'project_name') and self.project_name != project_name:
             warnings.warn(
                 f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
```
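Putting the new constructor parameters together: S3 export is opt-in per `Tracer` instance, and credentials may come either from the arguments or from the standard `AWS_ACCESS_KEY_ID`/`AWS_SECRET_ACCESS_KEY` environment variables. A usage sketch (the bucket name and region are illustrative):

```python
import os

from judgeval.common.tracer import Tracer

# Assumes JUDGMENT_API_KEY and JUDGMENT_ORG_ID are set in the environment,
# since the constructor defaults read them from there.
tracer = Tracer(
    project_name="my_project",
    use_s3=True,
    s3_bucket_name="my-trace-archive",  # hypothetical bucket
    s3_aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    s3_aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    s3_region_name="us-east-1",
)
# Every TraceClient built from this tracer now mirrors saved traces to S3
# via TraceManagerClient.save_trace() above.
```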
```diff
@@ -1320,100 +1356,192 @@ class Tracer:
 def wrap(client: Any) -> Any:
     """
     Wraps an API client to add tracing capabilities.
-    Supports OpenAI, Together, and
+    Supports OpenAI, Together, Anthropic, and Google GenAI clients.
+    Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
-
-    span_name, original_create = _get_client_config(client)
+    span_name, original_create, original_stream = _get_client_config(client)
 
-    #
-
-
-
-
-
-
-
-
+    # --- Define Traced Async Functions ---
+    async def traced_create_async(*args, **kwargs):
+        # [Existing logic - unchanged]
+        current_trace = current_trace_var.get()
+        if not current_trace:
+            if asyncio.iscoroutinefunction(original_create):
+                return await original_create(*args, **kwargs)
+            else:
+                return original_create(*args, **kwargs)
+
+        is_streaming = kwargs.get("stream", False)
+
+        with current_trace.span(span_name, span_type="llm") as span:
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Warn about token counting limitations with streaming
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
+                if not kwargs.get("stream_options", {}).get("include_usage"):
+                    warnings.warn(
+                        "OpenAI streaming calls don't include token counts by default. "
+                        "To enable token counting with streams, set stream_options={'include_usage': True} "
+                        "in your API call arguments.",
+                        UserWarning
+                    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                if is_streaming:
+                    stream_iterator = await original_create(*args, **kwargs)
+                    output_entry = span.record_output("<pending stream>")
+                    return _async_stream_wrapper(stream_iterator, client, output_entry)
+                else:
+                    awaited_response = await original_create(*args, **kwargs)
+                    output_data = _format_output_data(client, awaited_response)
+                    span.record_output(output_data)
+                    return awaited_response
+            except Exception as e:
+                print(f"Error during wrapped async API call ({span_name}): {e}")
+                span.record_output({"error": str(e)})
+                raise
+
+
+    # Function replacing .stream() - NOW returns the wrapper class instance
+    def traced_stream_async(*args, **kwargs):
+        current_trace = current_trace_var.get()
+        if not current_trace or not original_stream:
+            return original_stream(*args, **kwargs)
+        original_manager = original_stream(*args, **kwargs)
+        wrapper_manager = _TracedAsyncStreamManagerWrapper(
+            original_manager=original_manager,
+            client=client,
+            span_name=span_name,
+            trace_client=current_trace,
+            stream_wrapper_func=_async_stream_wrapper,
+            input_kwargs=kwargs
+        )
+        return wrapper_manager
+
+    # --- Define Traced Sync Functions ---
+    def traced_create_sync(*args, **kwargs):
+        # [Existing logic - unchanged]
+        current_trace = current_trace_var.get()
+        if not current_trace:
+            return original_create(*args, **kwargs)
+
+        is_streaming = kwargs.get("stream", False)
+
+        with current_trace.span(span_name, span_type="llm") as span:
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Warn about token counting limitations with streaming
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
+                if not kwargs.get("stream_options", {}).get("include_usage"):
+                    warnings.warn(
+                        "OpenAI streaming calls don't include token counts by default. "
+                        "To enable token counting with streams, set stream_options={'include_usage': True} "
+                        "in your API call arguments.",
+                        UserWarning
+                    )
+
+            try:
+                response_or_iterator = original_create(*args, **kwargs)
+            except Exception as e:
+                print(f"Error during wrapped sync API call ({span_name}): {e}")
+                span.record_output({"error": str(e)})
+                raise
+
+            if is_streaming:
+                output_entry = span.record_output("<pending stream>")
+                return _sync_stream_wrapper(response_or_iterator, client, output_entry)
+            else:
+                output_data = _format_output_data(client, response_or_iterator)
+                span.record_output(output_data)
+                return response_or_iterator
+
+
+    # Function replacing sync .stream()
+    def traced_stream_sync(*args, **kwargs):
+        current_trace = current_trace_var.get()
+        if not current_trace or not original_stream:
+            return original_stream(*args, **kwargs)
+        original_manager = original_stream(*args, **kwargs)
+        wrapper_manager = _TracedSyncStreamManagerWrapper(
+            original_manager=original_manager,
+            client=client,
+            span_name=span_name,
+            trace_client=current_trace,
+            stream_wrapper_func=_sync_stream_wrapper,
+            input_kwargs=kwargs
+        )
+        return wrapper_manager
+
+
+    # --- Assign Traced Methods to Client Instance ---
+    # [Assignment logic remains the same]
+    if isinstance(client, (AsyncOpenAI, AsyncTogether)):
+        client.chat.completions.create = traced_create_async
+        # Wrap the Responses API endpoint for OpenAI clients
+        if hasattr(client, "responses") and hasattr(client.responses, "create"):
+            # Capture the original responses.create
+            original_responses_create = client.responses.create
+            def traced_responses(*args, **kwargs):
+                # Get the current trace from contextvars
+                current_trace = current_trace_var.get()
+                # If no active trace, call the original
+                if not current_trace:
+                    return original_responses_create(*args, **kwargs)
+                # Trace this responses.create call
+                with current_trace.span(span_name, span_type="llm") as span:
+                    # Record raw input kwargs
+                    span.record_input(kwargs)
+                    # Make the actual API call
+                    response = original_responses_create(*args, **kwargs)
+                    # Record the output object
+                    span.record_output(response)
+                    return response
+            # Assign the traced wrapper
+            client.responses.create = traced_responses
+    elif isinstance(client, AsyncAnthropic):
+        client.messages.create = traced_create_async
+        if original_stream:
+            client.messages.stream = traced_stream_async
+    elif isinstance(client, genai.client.AsyncClient):
+        client.generate_content = traced_create_async
+    elif isinstance(client, (OpenAI, Together)):
+        client.chat.completions.create = traced_create_sync
+    elif isinstance(client, Anthropic):
+        client.messages.create = traced_create_sync
+        if original_stream:
+            client.messages.stream = traced_stream_sync
+    elif isinstance(client, genai.Client):
+        client.generate_content = traced_create_sync
 
-        with current_trace.span(span_name, span_type="llm") as span:
-            # Format and record the input parameters
-            input_data = _format_input_data(client, **kwargs)
-            span.record_input(input_data)
-
-            # Make the actual API call
-            try:
-                response = original_create(*args, **kwargs)
-            except Exception as e:
-                print(f"Error during API call: {e}")
-                raise
-
-            # Format and record the output
-            output_data = _format_output_data(client, response)
-            span.record_output(output_data)
-
-            return response
-
-
-    # Replace the original method with our traced version
-    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        client.chat.completions.create = traced_create
-    elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        client.messages.create = traced_create
-    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        client.models.generate_content = traced_create
-
     return client
 
 # Helper functions for client-specific operations
 
-def _get_client_config(client: ApiClient) -> tuple[str, callable]:
+def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable]]:
     """Returns configuration tuple for the given API client.
 
     Args:
         client: An instance of OpenAI, Together, or Anthropic client
 
     Returns:
-        tuple: (span_name, create_method)
+        tuple: (span_name, create_method, stream_method)
             - span_name: String identifier for tracing
             - create_method: Reference to the client's creation method
+            - stream_method: Reference to the client's stream method (if applicable)
 
     Raises:
         ValueError: If client type is not supported
     """
     if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return "OPENAI_API_CALL", client.chat.completions.create
+        return "OPENAI_API_CALL", client.chat.completions.create, None
     elif isinstance(client, (Together, AsyncTogether)):
-        return "TOGETHER_API_CALL", client.chat.completions.create
+        return "TOGETHER_API_CALL", client.chat.completions.create, None
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return "ANTHROPIC_API_CALL", client.messages.create
+        return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return "GOOGLE_API_CALL", client.models.generate_content
+        return "GOOGLE_API_CALL", client.models.generate_content, None
     raise ValueError(f"Unsupported client type: {type(client)}")
 
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
```
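Both new code paths are easiest to see from the caller's side. A usage sketch (model names are illustrative; `tracer.observe` is the decorator referenced by the `@observe` comment elsewhere in this file, and the usual `JUDGMENT_*`, `OPENAI_API_KEY`, and `ANTHROPIC_API_KEY` environment variables are assumed to be set):

```python
from anthropic import Anthropic
from openai import OpenAI

from judgeval.common.tracer import Tracer, wrap

tracer = Tracer(project_name="demo")
openai_client = wrap(OpenAI())
anthropic_client = wrap(Anthropic())

@tracer.observe()
def run() -> None:
    # OpenAI streaming via .create(stream=True): pass include_usage so the
    # final chunk carries token counts (otherwise the new warning fires).
    stream = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        pass

    # Anthropic context-manager streaming via .stream() is now traced too,
    # through _TracedSyncStreamManagerWrapper defined below.
    with anthropic_client.messages.stream(
        model="claude-3-5-sonnet-latest",
        max_tokens=64,
        messages=[{"role": "user", "content": "Hi"}],
    ) as events:
        for event in events:
            pass
```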
```diff
@@ -1478,6 +1606,26 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         }
     }
 
+# Define a blocklist of functions that should not be traced
+# These are typically utility functions, print statements, logging, etc.
+_TRACE_BLOCKLIST = {
+    # Built-in functions
+    'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
+    'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
+    'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
+    # Logging functions
+    'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
+    # Common utility functions
+    'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
+    # String operations
+    'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
+    # Dict operations
+    'get', 'items', 'keys', 'values', 'update',
+    # List operations
+    'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
+}
+
+
 # Add a new function for deep tracing at the module level
 def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
     """
```
```diff
@@ -1496,6 +1644,15 @@ def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
     if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
         return func
 
+    # Skip functions in the blocklist
+    if func.__name__ in _TRACE_BLOCKLIST:
+        return func
+
+    # Skip functions from certain modules (logging, sys, etc.)
+    if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
+        return func
+
+
     # Get function name for the span - check for custom name set by @observe
     func_name = getattr(func, '_judgment_span_name', func.__name__)
 
```
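The guards short-circuit before any wrapper is built. A quick check of the intended behavior (importing a private helper directly is for illustration only):

```python
import json

from judgeval.common.tracer import _create_deep_tracing_wrapper

# The guards return before `tracer` is ever touched, so None suffices here.
assert _create_deep_tracing_wrapper(print, None) is print            # builtins (and blocklisted name)
assert _create_deep_tracing_wrapper(json.dumps, None) is json.dumps  # blocked 'json' module
```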
```diff
@@ -1590,4 +1747,336 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
         return super().submit(ctx.run, func_with_bound_args)
 
     # Note: The `map` method would also need to be overridden for full context
-    # propagation if users rely on it, but `submit` is the most common use case.
+    # propagation if users rely on it, but `submit` is the most common use case.
+
+# Helper functions for stream processing
+# ---------------------------------------
+
+def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
+    """Extracts the text content from a stream chunk based on the client type."""
+    try:
+        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
+            return chunk.choices[0].delta.content
+        elif isinstance(client, (Anthropic, AsyncAnthropic)):
+            # Anthropic streams various event types, we only care for content blocks
+            if chunk.type == "content_block_delta":
+                return chunk.delta.text
+        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+            # Google streams Candidate objects
+            if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
+                return chunk.candidates[0].content.parts[0].text
+    except (AttributeError, IndexError, KeyError):
+        # Handle cases where chunk structure is unexpected or doesn't contain content
+        pass # Return None
+    return None
+
+def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
+    """Extracts usage data if present in the *final* chunk (client-specific)."""
+    try:
+        # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
+        # This typically requires specific API versions or settings. Often usage is *not* streamed.
+        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
+            # Check if usage is directly on the chunk (some models might do this)
+            if hasattr(chunk, 'usage') and chunk.usage:
+                return {
+                    "prompt_tokens": chunk.usage.prompt_tokens,
+                    "completion_tokens": chunk.usage.completion_tokens,
+                    "total_tokens": chunk.usage.total_tokens
+                }
+            # Check if usage is nested within choices (less common for final chunk, but check)
+            elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
+                usage = chunk.choices[0].usage
+                return {
+                    "prompt_tokens": usage.prompt_tokens,
+                    "completion_tokens": usage.completion_tokens,
+                    "total_tokens": usage.total_tokens
+                }
+        # Anthropic includes usage in the 'message_stop' event type
+        elif isinstance(client, (Anthropic, AsyncAnthropic)):
+            if chunk.type == "message_stop":
+                # Anthropic final usage is often attached to the *message* object, not the chunk directly
+                # The API might provide a way to get the final message object, but typically not in the stream itself.
+                # Let's assume for now usage might appear in the final *chunk* metadata if supported.
+                # This is a placeholder - Anthropic usage typically needs a separate call or context.
+                pass
+        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+            # Google provides usage metadata on the full response object, not typically streamed per chunk.
+            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
+            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                return {
+                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
+                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
+                    "total_tokens": chunk.usage_metadata.total_token_count
+                }
+
+    except (AttributeError, IndexError, KeyError, TypeError):
+        # Handle cases where usage data is missing or malformed
+        pass # Return None
+    return None
+
+
+# --- Sync Stream Wrapper ---
+def _sync_stream_wrapper(
+    original_stream: Iterator,
+    client: ApiClient,
+    output_entry: TraceEntry
+) -> Generator[Any, None, None]:
+    """Wraps a synchronous stream iterator to capture content and update the trace."""
+    content_parts = [] # Use a list instead of string concatenation
+    final_usage = None
+    last_chunk = None
+    try:
+        for chunk in original_stream:
+            content_part = _extract_content_from_chunk(client, chunk)
+            if content_part:
+                content_parts.append(content_part) # Append to list instead of concatenating
+            last_chunk = chunk # Keep track of the last chunk for potential usage data
+            yield chunk # Pass the chunk to the caller
+    finally:
+        # Attempt to extract usage from the last chunk received
+        if last_chunk:
+            final_usage = _extract_usage_from_final_chunk(client, last_chunk)
+
+        # Update the trace entry with the accumulated content and usage
+        output_entry.output = {
+            "content": "".join(content_parts), # Join list at the end
+            "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
+            "streamed": True
+        }
+        # Note: We might need to adjust _serialize_output if this dict causes issues,
+        # but Pydantic's model_dump should handle dicts.
+
+# --- Async Stream Wrapper ---
+async def _async_stream_wrapper(
+    original_stream: AsyncIterator,
+    client: ApiClient,
+    output_entry: TraceEntry
+) -> AsyncGenerator[Any, None]:
+    # [Existing logic - unchanged]
+    content_parts = [] # Use a list instead of string concatenation
+    final_usage_data = None
+    last_content_chunk = None
+    anthropic_input_tokens = 0
+    anthropic_output_tokens = 0
+
+    target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
+
+    try:
+        async for chunk in original_stream:
+            # Check for OpenAI's final usage chunk
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
+                final_usage_data = {
+                    "prompt_tokens": chunk.usage.prompt_tokens,
+                    "completion_tokens": chunk.usage.completion_tokens,
+                    "total_tokens": chunk.usage.total_tokens
+                }
+                yield chunk
+                continue
+
+            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
+                if chunk.type == "message_start":
+                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
+                        anthropic_input_tokens = chunk.message.usage.input_tokens
+                elif chunk.type == "message_delta":
+                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                        anthropic_output_tokens += chunk.usage.output_tokens
+
+            content_part = _extract_content_from_chunk(client, chunk)
+            if content_part:
+                content_parts.append(content_part) # Append to list instead of concatenating
+                last_content_chunk = chunk
+
+            yield chunk
+    finally:
+        anthropic_final_usage = None
+        if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
+            anthropic_final_usage = {
+                "input_tokens": anthropic_input_tokens,
+                "output_tokens": anthropic_output_tokens,
+                "total_tokens": anthropic_input_tokens + anthropic_output_tokens
+            }
+
+        usage_info = None
+        if final_usage_data:
+            usage_info = final_usage_data
+        elif anthropic_final_usage:
+            usage_info = anthropic_final_usage
+        elif last_content_chunk:
+            usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
+
+        if output_entry and hasattr(output_entry, 'output'):
+            output_entry.output = {
+                "content": "".join(content_parts), # Join list at the end
+                "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
+                "streamed": True
+            }
+            start_ts = getattr(output_entry, 'created_at', time.time())
+            output_entry.duration = time.time() - start_ts
+        # else: # Handle error case if necessary, but remove debug print
+
+# --- Define Context Manager Wrapper Classes ---
+class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
+    """Wraps an original async stream manager to add tracing."""
+    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+        self._original_manager = original_manager
+        self._client = client
+        self._span_name = span_name
+        self._trace_client = trace_client
+        self._stream_wrapper_func = stream_wrapper_func
+        self._input_kwargs = input_kwargs
+        self._parent_span_id_at_entry = None
+
+    async def __aenter__(self):
+        self._parent_span_id_at_entry = current_span_var.get()
+        if not self._trace_client:
+            # If no trace, just delegate to the original manager
+            return await self._original_manager.__aenter__()
+
+        # --- Manually create the 'enter' entry ---
+        start_time = time.time()
+        span_id = str(uuid.uuid4())
+        current_depth = 0
+        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        self._trace_client._span_depths[span_id] = current_depth
+        enter_entry = TraceEntry(
+            type="enter", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+            created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+        )
+        self._trace_client.add_entry(enter_entry)
+        # --- End manual 'enter' entry ---
+
+        # Set the current span ID in contextvars
+        self._span_context_token = current_span_var.set(span_id)
+
+        # Manually create 'input' entry
+        input_data = _format_input_data(self._client, **self._input_kwargs)
+        input_entry = TraceEntry(
+            type="input", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+            created_at=time.time(), inputs=input_data, span_type="llm"
+        )
+        self._trace_client.add_entry(input_entry)
+
+        # Call the original __aenter__
+        raw_iterator = await self._original_manager.__aenter__()
+
+        # Manually create pending 'output' entry
+        output_entry = TraceEntry(
+            type="output", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+            created_at=time.time(), output="<pending stream>", span_type="llm"
+        )
+        self._trace_client.add_entry(output_entry)
+
+        # Wrap the raw iterator
+        wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+        return wrapped_iterator
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        # Manually create the 'exit' entry
+        if hasattr(self, '_span_context_token'):
+            span_id = current_span_var.get()
+            start_time_for_duration = 0
+            for entry in reversed(self._trace_client.entries):
+                if entry.span_id == span_id and entry.type == 'enter':
+                    start_time_for_duration = entry.created_at
+                    break
+            duration = time.time() - start_time_for_duration if start_time_for_duration else None
+            exit_depth = self._trace_client._span_depths.get(span_id, 0)
+            exit_entry = TraceEntry(
+                type="exit", function=self._span_name, span_id=span_id,
+                trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                created_at=time.time(), duration=duration, span_type="llm"
+            )
+            self._trace_client.add_entry(exit_entry)
+            if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
+            current_span_var.reset(self._span_context_token)
+            delattr(self, '_span_context_token')
+
+        # Delegate __aexit__
+        if hasattr(self._original_manager, "__aexit__"):
+            return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
+        return None
+
+class _TracedSyncStreamManagerWrapper(AbstractContextManager):
+    """Wraps an original sync stream manager to add tracing."""
+    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+        self._original_manager = original_manager
+        self._client = client
+        self._span_name = span_name
+        self._trace_client = trace_client
+        self._stream_wrapper_func = stream_wrapper_func
+        self._input_kwargs = input_kwargs
+        self._parent_span_id_at_entry = None
+
+    def __enter__(self):
+        self._parent_span_id_at_entry = current_span_var.get()
+        if not self._trace_client:
+            return self._original_manager.__enter__()
+
+        # Manually create 'enter' entry
+        start_time = time.time()
+        span_id = str(uuid.uuid4())
+        current_depth = 0
+        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        self._trace_client._span_depths[span_id] = current_depth
+        enter_entry = TraceEntry(
+            type="enter", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+            created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+        )
+        self._trace_client.add_entry(enter_entry)
+        self._span_context_token = current_span_var.set(span_id)
+
+        # Manually create 'input' entry
+        input_data = _format_input_data(self._client, **self._input_kwargs)
+        input_entry = TraceEntry(
+            type="input", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+            created_at=time.time(), inputs=input_data, span_type="llm"
+        )
+        self._trace_client.add_entry(input_entry)
+
+        # Call original __enter__
+        raw_iterator = self._original_manager.__enter__()
+
+        # Manually create 'output' entry (pending)
+        output_entry = TraceEntry(
+            type="output", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+            created_at=time.time(), output="<pending stream>", span_type="llm"
+        )
+        self._trace_client.add_entry(output_entry)
+
+        # Wrap the raw iterator
+        wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+        return wrapped_iterator
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Manually create 'exit' entry
+        if hasattr(self, '_span_context_token'):
+            span_id = current_span_var.get()
+            start_time_for_duration = 0
+            for entry in reversed(self._trace_client.entries):
+                if entry.span_id == span_id and entry.type == 'enter':
+                    start_time_for_duration = entry.created_at
+                    break
+            duration = time.time() - start_time_for_duration if start_time_for_duration else None
+            exit_depth = self._trace_client._span_depths.get(span_id, 0)
+            exit_entry = TraceEntry(
+                type="exit", function=self._span_name, span_id=span_id,
+                trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                created_at=time.time(), duration=duration, span_type="llm"
+            )
+            self._trace_client.add_entry(exit_entry)
+            if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
+            current_span_var.reset(self._span_context_token)
+            delattr(self, '_span_context_token')
+
+        # Delegate __exit__
+        if hasattr(self._original_manager, "__exit__"):
+            return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
+        return None
```