judgeval 0.0.55__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- judgeval/common/api/__init__.py +3 -0
- judgeval/common/api/api.py +352 -0
- judgeval/common/api/constants.py +165 -0
- judgeval/common/storage/__init__.py +6 -0
- judgeval/common/tracer/__init__.py +31 -0
- judgeval/common/tracer/constants.py +22 -0
- judgeval/common/tracer/core.py +1916 -0
- judgeval/common/tracer/otel_exporter.py +108 -0
- judgeval/common/tracer/otel_span_processor.py +234 -0
- judgeval/common/tracer/span_processor.py +37 -0
- judgeval/common/tracer/span_transformer.py +211 -0
- judgeval/common/tracer/trace_manager.py +92 -0
- judgeval/common/utils.py +2 -2
- judgeval/constants.py +3 -30
- judgeval/data/datasets/eval_dataset_client.py +29 -156
- judgeval/data/judgment_types.py +4 -12
- judgeval/data/result.py +1 -1
- judgeval/data/scorer_data.py +2 -2
- judgeval/data/scripts/openapi_transform.py +1 -1
- judgeval/data/trace.py +66 -1
- judgeval/data/trace_run.py +0 -3
- judgeval/evaluation_run.py +0 -2
- judgeval/integrations/langgraph.py +43 -164
- judgeval/judgment_client.py +17 -211
- judgeval/run_evaluation.py +209 -611
- judgeval/scorers/__init__.py +2 -6
- judgeval/scorers/base_scorer.py +4 -23
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +3 -3
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +215 -0
- judgeval/scorers/score.py +2 -1
- judgeval/scorers/utils.py +1 -13
- judgeval/utils/requests.py +21 -0
- judgeval-0.1.0.dist-info/METADATA +202 -0
- {judgeval-0.0.55.dist-info → judgeval-0.1.0.dist-info}/RECORD +37 -29
- judgeval/common/tracer.py +0 -3215
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +0 -73
- judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -53
- judgeval-0.0.55.dist-info/METADATA +0 -1384
- /judgeval/common/{s3_storage.py → storage/s3_storage.py} +0 -0
- {judgeval-0.0.55.dist-info → judgeval-0.1.0.dist-info}/WHEEL +0 -0
- {judgeval-0.0.55.dist-info → judgeval-0.1.0.dist-info}/licenses/LICENSE.md +0 -0
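
Note on the restructuring: the monolithic `judgeval/common/tracer.py` module (deleted below) is replaced by the new `judgeval/common/tracer/` package. Because a package and a module occupy the same dotted import path, existing imports can keep working; a minimal sketch, assuming the new `__init__.py` re-exports the public `Tracer` class (this diff does not show the package's exports):

```python
# Resolves in 0.0.55 (module: judgeval/common/tracer.py) and, assuming the
# new package re-exports Tracer, in 0.1.0 (package: judgeval/common/tracer/).
from judgeval.common.tracer import Tracer

# Hypothetical usage based on the constructor shown in the deleted module.
tracer = Tracer(project_name="my_project")
```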
judgeval/common/tracer.py
DELETED
@@ -1,3215 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Tracing system for judgeval that allows for function tracing using decorators.
|
3
|
-
"""
|
4
|
-
|
5
|
-
# Standard library imports
|
6
|
-
import asyncio
|
7
|
-
import functools
|
8
|
-
import inspect
|
9
|
-
import os
|
10
|
-
import site
|
11
|
-
import sysconfig
|
12
|
-
import threading
|
13
|
-
import time
|
14
|
-
import traceback
|
15
|
-
import uuid
|
16
|
-
import contextvars
|
17
|
-
import sys
|
18
|
-
import json
|
19
|
-
from contextlib import (
|
20
|
-
contextmanager,
|
21
|
-
AbstractAsyncContextManager,
|
22
|
-
AbstractContextManager,
|
23
|
-
) # Import context manager bases
|
24
|
-
from dataclasses import dataclass
|
25
|
-
from datetime import datetime, timezone
|
26
|
-
from http import HTTPStatus
|
27
|
-
from typing import (
|
28
|
-
Any,
|
29
|
-
Callable,
|
30
|
-
Dict,
|
31
|
-
Generator,
|
32
|
-
List,
|
33
|
-
Literal,
|
34
|
-
Optional,
|
35
|
-
Tuple,
|
36
|
-
Union,
|
37
|
-
AsyncGenerator,
|
38
|
-
TypeAlias,
|
39
|
-
)
|
40
|
-
from rich import print as rprint
|
41
|
-
import types
|
42
|
-
|
43
|
-
# Third-party imports
|
44
|
-
from requests import RequestException
|
45
|
-
from judgeval.utils.requests import requests
|
46
|
-
from litellm import cost_per_token as _original_cost_per_token
|
47
|
-
from openai import OpenAI, AsyncOpenAI
|
48
|
-
from together import Together, AsyncTogether
|
49
|
-
from anthropic import Anthropic, AsyncAnthropic
|
50
|
-
from google import genai
|
51
|
-
|
52
|
-
# Local application/library-specific imports
|
53
|
-
from judgeval.constants import (
|
54
|
-
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
55
|
-
JUDGMENT_TRACES_UPSERT_API_URL,
|
56
|
-
JUDGMENT_TRACES_FETCH_API_URL,
|
57
|
-
JUDGMENT_TRACES_DELETE_API_URL,
|
58
|
-
JUDGMENT_PROJECT_DELETE_API_URL,
|
59
|
-
JUDGMENT_TRACES_SPANS_BATCH_API_URL,
|
60
|
-
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
|
61
|
-
)
|
62
|
-
from judgeval.data import Example, Trace, TraceSpan, TraceUsage
|
63
|
-
from judgeval.scorers import APIScorerConfig, BaseScorer
|
64
|
-
from judgeval.evaluation_run import EvaluationRun
|
65
|
-
from judgeval.common.utils import ExcInfo, validate_api_key
|
66
|
-
from judgeval.common.logger import judgeval_logger
|
67
|
-
|
68
|
-
# Standard library imports needed for the new class
|
69
|
-
import concurrent.futures
|
70
|
-
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
71
|
-
import queue
|
72
|
-
import atexit
|
73
|
-
|
74
|
-
# Define context variables for tracking the current trace and the current span within a trace
|
75
|
-
current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
|
76
|
-
"current_trace", default=None
|
77
|
-
)
|
78
|
-
current_span_var = contextvars.ContextVar[Optional[str]](
|
79
|
-
"current_span", default=None
|
80
|
-
) # ContextVar for the active span id
|
81
|
-
|
82
|
-
# Define type aliases for better code readability and maintainability
|
83
|
-
ApiClient: TypeAlias = Union[
|
84
|
-
OpenAI,
|
85
|
-
Together,
|
86
|
-
Anthropic,
|
87
|
-
AsyncOpenAI,
|
88
|
-
AsyncAnthropic,
|
89
|
-
AsyncTogether,
|
90
|
-
genai.Client,
|
91
|
-
genai.client.AsyncClient,
|
92
|
-
] # Supported API clients
|
93
|
-
SpanType = Literal["span", "tool", "llm", "evaluation", "chain"]
|
94
|
-
|
95
|
-
|
96
|
-
# Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
|
97
|
-
@dataclass
|
98
|
-
class TraceAnnotation:
|
99
|
-
"""Represents a single annotation for a trace span."""
|
100
|
-
|
101
|
-
span_id: str
|
102
|
-
text: str
|
103
|
-
label: str
|
104
|
-
score: int
|
105
|
-
|
106
|
-
def to_dict(self) -> dict:
|
107
|
-
"""Convert the annotation to a dictionary format for storage/transmission."""
|
108
|
-
return {
|
109
|
-
"span_id": self.span_id,
|
110
|
-
"annotation": {"text": self.text, "label": self.label, "score": self.score},
|
111
|
-
}
|
112
|
-
|
113
|
-
|
114
|
-
class TraceManagerClient:
|
115
|
-
"""
|
116
|
-
Client for handling trace endpoints with the Judgment API
|
117
|
-
|
118
|
-
|
119
|
-
Operations include:
|
120
|
-
- Fetching a trace by id
|
121
|
-
- Saving a trace
|
122
|
-
- Deleting a trace
|
123
|
-
"""
|
124
|
-
|
125
|
-
def __init__(
|
126
|
-
self,
|
127
|
-
judgment_api_key: str,
|
128
|
-
organization_id: str,
|
129
|
-
tracer: Optional["Tracer"] = None,
|
130
|
-
):
|
131
|
-
self.judgment_api_key = judgment_api_key
|
132
|
-
self.organization_id = organization_id
|
133
|
-
self.tracer = tracer
|
134
|
-
|
135
|
-
def fetch_trace(self, trace_id: str):
|
136
|
-
"""
|
137
|
-
Fetch a trace by its id
|
138
|
-
"""
|
139
|
-
response = requests.post(
|
140
|
-
JUDGMENT_TRACES_FETCH_API_URL,
|
141
|
-
json={
|
142
|
-
"trace_id": trace_id,
|
143
|
-
},
|
144
|
-
headers={
|
145
|
-
"Content-Type": "application/json",
|
146
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
147
|
-
"X-Organization-Id": self.organization_id,
|
148
|
-
},
|
149
|
-
verify=True,
|
150
|
-
)
|
151
|
-
|
152
|
-
if response.status_code != HTTPStatus.OK:
|
153
|
-
raise ValueError(f"Failed to fetch traces: {response.text}")
|
154
|
-
|
155
|
-
return response.json()
|
156
|
-
|
157
|
-
def upsert_trace(
|
158
|
-
self,
|
159
|
-
trace_data: dict,
|
160
|
-
offline_mode: bool = False,
|
161
|
-
show_link: bool = True,
|
162
|
-
final_save: bool = True,
|
163
|
-
):
|
164
|
-
"""
|
165
|
-
Upserts a trace to the Judgment API (always overwrites if exists).
|
166
|
-
|
167
|
-
Args:
|
168
|
-
trace_data: The trace data to upsert
|
169
|
-
offline_mode: Whether running in offline mode
|
170
|
-
show_link: Whether to show the UI link (for live tracing)
|
171
|
-
final_save: Whether this is the final save (controls S3 saving)
|
172
|
-
|
173
|
-
Returns:
|
174
|
-
dict: Server response containing UI URL and other metadata
|
175
|
-
"""
|
176
|
-
|
177
|
-
def fallback_encoder(obj):
|
178
|
-
"""
|
179
|
-
Custom JSON encoder fallback.
|
180
|
-
Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
|
181
|
-
"""
|
182
|
-
try:
|
183
|
-
return repr(obj)
|
184
|
-
except Exception:
|
185
|
-
try:
|
186
|
-
return str(obj)
|
187
|
-
except Exception as e:
|
188
|
-
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
189
|
-
|
190
|
-
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
191
|
-
|
192
|
-
response = requests.post(
|
193
|
-
JUDGMENT_TRACES_UPSERT_API_URL,
|
194
|
-
data=serialized_trace_data,
|
195
|
-
headers={
|
196
|
-
"Content-Type": "application/json",
|
197
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
198
|
-
"X-Organization-Id": self.organization_id,
|
199
|
-
},
|
200
|
-
verify=True,
|
201
|
-
)
|
202
|
-
|
203
|
-
if response.status_code != HTTPStatus.OK:
|
204
|
-
raise ValueError(f"Failed to upsert trace data: {response.text}")
|
205
|
-
|
206
|
-
# Parse server response
|
207
|
-
server_response = response.json()
|
208
|
-
|
209
|
-
# If S3 storage is enabled, save to S3 only on final save
|
210
|
-
if self.tracer and self.tracer.use_s3 and final_save:
|
211
|
-
try:
|
212
|
-
s3_key = self.tracer.s3_storage.save_trace(
|
213
|
-
trace_data=trace_data,
|
214
|
-
trace_id=trace_data["trace_id"],
|
215
|
-
project_name=trace_data["project_name"],
|
216
|
-
)
|
217
|
-
judgeval_logger.info(f"Trace also saved to S3 at key: {s3_key}")
|
218
|
-
except Exception as e:
|
219
|
-
judgeval_logger.warning(f"Failed to save trace to S3: {str(e)}")
|
220
|
-
|
221
|
-
if not offline_mode and show_link and "ui_results_url" in server_response:
|
222
|
-
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
223
|
-
rprint(pretty_str)
|
224
|
-
|
225
|
-
return server_response
|
226
|
-
|
227
|
-
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
228
|
-
def save_annotation(self, annotation: TraceAnnotation):
|
229
|
-
json_data = {
|
230
|
-
"span_id": annotation.span_id,
|
231
|
-
"annotation": {
|
232
|
-
"text": annotation.text,
|
233
|
-
"label": annotation.label,
|
234
|
-
"score": annotation.score,
|
235
|
-
},
|
236
|
-
}
|
237
|
-
|
238
|
-
response = requests.post(
|
239
|
-
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
240
|
-
json=json_data,
|
241
|
-
headers={
|
242
|
-
"Content-Type": "application/json",
|
243
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
244
|
-
"X-Organization-Id": self.organization_id,
|
245
|
-
},
|
246
|
-
verify=True,
|
247
|
-
)
|
248
|
-
|
249
|
-
if response.status_code != HTTPStatus.OK:
|
250
|
-
raise ValueError(f"Failed to save annotation: {response.text}")
|
251
|
-
|
252
|
-
return response.json()
|
253
|
-
|
254
|
-
def delete_trace(self, trace_id: str):
|
255
|
-
"""
|
256
|
-
Delete a trace from the database.
|
257
|
-
"""
|
258
|
-
response = requests.delete(
|
259
|
-
JUDGMENT_TRACES_DELETE_API_URL,
|
260
|
-
json={
|
261
|
-
"trace_ids": [trace_id],
|
262
|
-
},
|
263
|
-
headers={
|
264
|
-
"Content-Type": "application/json",
|
265
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
266
|
-
"X-Organization-Id": self.organization_id,
|
267
|
-
},
|
268
|
-
)
|
269
|
-
|
270
|
-
if response.status_code != HTTPStatus.OK:
|
271
|
-
raise ValueError(f"Failed to delete trace: {response.text}")
|
272
|
-
|
273
|
-
return response.json()
|
274
|
-
|
275
|
-
def delete_traces(self, trace_ids: List[str]):
|
276
|
-
"""
|
277
|
-
Delete a batch of traces from the database.
|
278
|
-
"""
|
279
|
-
response = requests.delete(
|
280
|
-
JUDGMENT_TRACES_DELETE_API_URL,
|
281
|
-
json={
|
282
|
-
"trace_ids": trace_ids,
|
283
|
-
},
|
284
|
-
headers={
|
285
|
-
"Content-Type": "application/json",
|
286
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
287
|
-
"X-Organization-Id": self.organization_id,
|
288
|
-
},
|
289
|
-
)
|
290
|
-
|
291
|
-
if response.status_code != HTTPStatus.OK:
|
292
|
-
raise ValueError(f"Failed to delete trace: {response.text}")
|
293
|
-
|
294
|
-
return response.json()
|
295
|
-
|
296
|
-
def delete_project(self, project_name: str):
|
297
|
-
"""
|
298
|
-
Deletes a project from the server. Which also deletes all evaluations and traces associated with the project.
|
299
|
-
"""
|
300
|
-
response = requests.delete(
|
301
|
-
JUDGMENT_PROJECT_DELETE_API_URL,
|
302
|
-
json={
|
303
|
-
"project_name": project_name,
|
304
|
-
},
|
305
|
-
headers={
|
306
|
-
"Content-Type": "application/json",
|
307
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
308
|
-
"X-Organization-Id": self.organization_id,
|
309
|
-
},
|
310
|
-
)
|
311
|
-
|
312
|
-
if response.status_code != HTTPStatus.OK:
|
313
|
-
raise ValueError(f"Failed to delete traces: {response.text}")
|
314
|
-
|
315
|
-
return response.json()
|
316
|
-
|
317
|
-
|
318
|
-
class TraceClient:
|
319
|
-
"""Client for managing a single trace context"""
|
320
|
-
|
321
|
-
def __init__(
|
322
|
-
self,
|
323
|
-
tracer: "Tracer",
|
324
|
-
trace_id: Optional[str] = None,
|
325
|
-
name: str = "default",
|
326
|
-
project_name: str | None = None,
|
327
|
-
enable_monitoring: bool = True,
|
328
|
-
enable_evaluations: bool = True,
|
329
|
-
parent_trace_id: Optional[str] = None,
|
330
|
-
parent_name: Optional[str] = None,
|
331
|
-
):
|
332
|
-
self.name = name
|
333
|
-
self.trace_id = trace_id or str(uuid.uuid4())
|
334
|
-
self.project_name = project_name or "default_project"
|
335
|
-
self.tracer = tracer
|
336
|
-
self.enable_monitoring = enable_monitoring
|
337
|
-
self.enable_evaluations = enable_evaluations
|
338
|
-
self.parent_trace_id = parent_trace_id
|
339
|
-
self.parent_name = parent_name
|
340
|
-
self.customer_id: Optional[str] = None # Added customer_id attribute
|
341
|
-
self.tags: List[Union[str, set, tuple]] = [] # Added tags attribute
|
342
|
-
self.metadata: Dict[str, Any] = {}
|
343
|
-
self.has_notification: Optional[bool] = False # Initialize has_notification
|
344
|
-
self.update_id: int = 1 # Initialize update_id to 1, increments with each save
|
345
|
-
self.trace_spans: List[TraceSpan] = []
|
346
|
-
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
347
|
-
self.evaluation_runs: List[EvaluationRun] = []
|
348
|
-
self.annotations: List[TraceAnnotation] = []
|
349
|
-
self.start_time: Optional[float] = (
|
350
|
-
None # Will be set after first successful save
|
351
|
-
)
|
352
|
-
self.trace_manager_client = TraceManagerClient(
|
353
|
-
tracer.api_key, tracer.organization_id, tracer
|
354
|
-
)
|
355
|
-
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
356
|
-
|
357
|
-
# Get background span service from tracer
|
358
|
-
self.background_span_service = (
|
359
|
-
tracer.get_background_span_service() if tracer else None
|
360
|
-
)
|
361
|
-
|
362
|
-
def get_current_span(self):
|
363
|
-
"""Get the current span from the context var"""
|
364
|
-
return self.tracer.get_current_span()
|
365
|
-
|
366
|
-
def set_current_span(self, span: Any):
|
367
|
-
"""Set the current span from the context var"""
|
368
|
-
return self.tracer.set_current_span(span)
|
369
|
-
|
370
|
-
def reset_current_span(self, token: Any):
|
371
|
-
"""Reset the current span from the context var"""
|
372
|
-
self.tracer.reset_current_span(token)
|
373
|
-
|
374
|
-
@contextmanager
|
375
|
-
def span(self, name: str, span_type: SpanType = "span"):
|
376
|
-
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
377
|
-
is_first_span = len(self.trace_spans) == 0
|
378
|
-
if is_first_span:
|
379
|
-
try:
|
380
|
-
trace_id, server_response = self.save(final_save=False)
|
381
|
-
# Set start_time after first successful save
|
382
|
-
# Link will be shown by upsert_trace method
|
383
|
-
except Exception as e:
|
384
|
-
judgeval_logger.warning(
|
385
|
-
f"Failed to save initial trace for live tracking: {e}"
|
386
|
-
)
|
387
|
-
start_time = time.time()
|
388
|
-
|
389
|
-
# Generate a unique ID for *this specific span invocation*
|
390
|
-
span_id = str(uuid.uuid4())
|
391
|
-
|
392
|
-
parent_span_id = (
|
393
|
-
self.get_current_span()
|
394
|
-
) # Get ID of the parent span from context var
|
395
|
-
token = self.set_current_span(
|
396
|
-
span_id
|
397
|
-
) # Set *this* span's ID as the current one
|
398
|
-
|
399
|
-
current_depth = 0
|
400
|
-
if parent_span_id and parent_span_id in self._span_depths:
|
401
|
-
current_depth = self._span_depths[parent_span_id] + 1
|
402
|
-
|
403
|
-
self._span_depths[span_id] = current_depth # Store depth by span_id
|
404
|
-
|
405
|
-
span = TraceSpan(
|
406
|
-
span_id=span_id,
|
407
|
-
trace_id=self.trace_id,
|
408
|
-
depth=current_depth,
|
409
|
-
message=name,
|
410
|
-
created_at=start_time,
|
411
|
-
span_type=span_type,
|
412
|
-
parent_span_id=parent_span_id,
|
413
|
-
function=name,
|
414
|
-
)
|
415
|
-
self.add_span(span)
|
416
|
-
|
417
|
-
# Queue span with initial state (input phase)
|
418
|
-
if self.background_span_service:
|
419
|
-
self.background_span_service.queue_span(span, span_state="input")
|
420
|
-
|
421
|
-
try:
|
422
|
-
yield self
|
423
|
-
finally:
|
424
|
-
duration = time.time() - start_time
|
425
|
-
span.duration = duration
|
426
|
-
|
427
|
-
# Queue span with completed state (output phase)
|
428
|
-
if self.background_span_service:
|
429
|
-
self.background_span_service.queue_span(span, span_state="completed")
|
430
|
-
|
431
|
-
# Clean up depth tracking for this span_id
|
432
|
-
if span_id in self._span_depths:
|
433
|
-
del self._span_depths[span_id]
|
434
|
-
# Reset context var
|
435
|
-
self.reset_current_span(token)
|
436
|
-
|
437
|
-
def async_evaluate(
|
438
|
-
self,
|
439
|
-
scorers: List[Union[APIScorerConfig, BaseScorer]],
|
440
|
-
example: Optional[Example] = None,
|
441
|
-
input: Optional[str] = None,
|
442
|
-
actual_output: Optional[Union[str, List[str]]] = None,
|
443
|
-
expected_output: Optional[Union[str, List[str]]] = None,
|
444
|
-
context: Optional[List[str]] = None,
|
445
|
-
retrieval_context: Optional[List[str]] = None,
|
446
|
-
tools_called: Optional[List[str]] = None,
|
447
|
-
expected_tools: Optional[List[str]] = None,
|
448
|
-
additional_metadata: Optional[Dict[str, Any]] = None,
|
449
|
-
model: Optional[str] = None,
|
450
|
-
span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
|
451
|
-
):
|
452
|
-
if not self.enable_evaluations:
|
453
|
-
return
|
454
|
-
|
455
|
-
start_time = time.time() # Record start time
|
456
|
-
|
457
|
-
try:
|
458
|
-
# Load appropriate implementations for all scorers
|
459
|
-
if not scorers:
|
460
|
-
judgeval_logger.warning("No valid scorers available for evaluation")
|
461
|
-
return
|
462
|
-
|
463
|
-
except Exception as e:
|
464
|
-
judgeval_logger.warning(f"Failed to load scorers: {str(e)}")
|
465
|
-
return
|
466
|
-
|
467
|
-
# If example is not provided, create one from the individual parameters
|
468
|
-
if example is None:
|
469
|
-
# Check if any of the individual parameters are provided
|
470
|
-
if any(
|
471
|
-
param is not None
|
472
|
-
for param in [
|
473
|
-
input,
|
474
|
-
actual_output,
|
475
|
-
expected_output,
|
476
|
-
context,
|
477
|
-
retrieval_context,
|
478
|
-
tools_called,
|
479
|
-
expected_tools,
|
480
|
-
additional_metadata,
|
481
|
-
]
|
482
|
-
):
|
483
|
-
example = Example(
|
484
|
-
input=input,
|
485
|
-
actual_output=actual_output,
|
486
|
-
expected_output=expected_output,
|
487
|
-
context=context,
|
488
|
-
retrieval_context=retrieval_context,
|
489
|
-
tools_called=tools_called,
|
490
|
-
expected_tools=expected_tools,
|
491
|
-
additional_metadata=additional_metadata,
|
492
|
-
)
|
493
|
-
else:
|
494
|
-
raise ValueError(
|
495
|
-
"Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
|
496
|
-
)
|
497
|
-
|
498
|
-
# Check examples before creating evaluation run
|
499
|
-
|
500
|
-
# check_examples([example], scorers)
|
501
|
-
|
502
|
-
span_id_to_use = span_id if span_id is not None else self.get_current_span()
|
503
|
-
|
504
|
-
eval_run = EvaluationRun(
|
505
|
-
organization_id=self.tracer.organization_id,
|
506
|
-
project_name=self.project_name,
|
507
|
-
eval_name=f"{self.name.capitalize()}-"
|
508
|
-
f"{span_id_to_use}-" # Keep original eval name format using context var if available
|
509
|
-
f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
|
510
|
-
examples=[example],
|
511
|
-
scorers=scorers,
|
512
|
-
model=model,
|
513
|
-
judgment_api_key=self.tracer.api_key,
|
514
|
-
trace_span_id=span_id_to_use,
|
515
|
-
)
|
516
|
-
|
517
|
-
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
518
|
-
|
519
|
-
# Queue evaluation run through background service
|
520
|
-
if self.background_span_service and span_id_to_use:
|
521
|
-
# Get the current span data to avoid race conditions
|
522
|
-
current_span = self.span_id_to_span.get(span_id_to_use)
|
523
|
-
if current_span:
|
524
|
-
self.background_span_service.queue_evaluation_run(
|
525
|
-
eval_run, span_id=span_id_to_use, span_data=current_span
|
526
|
-
)
|
527
|
-
|
528
|
-
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
529
|
-
# --- Modification: Use span_id from eval_run ---
|
530
|
-
current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
|
531
|
-
# print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
|
532
|
-
# --- End Modification ---
|
533
|
-
|
534
|
-
if current_span_id:
|
535
|
-
span = self.span_id_to_span[current_span_id]
|
536
|
-
span.has_evaluation = True # Set the has_evaluation flag
|
537
|
-
self.evaluation_runs.append(eval_run)
|
538
|
-
|
539
|
-
def add_annotation(self, annotation: TraceAnnotation):
|
540
|
-
"""Add an annotation to this trace context"""
|
541
|
-
self.annotations.append(annotation)
|
542
|
-
return self
|
543
|
-
|
544
|
-
def record_input(self, inputs: dict):
|
545
|
-
current_span_id = self.get_current_span()
|
546
|
-
if current_span_id:
|
547
|
-
span = self.span_id_to_span[current_span_id]
|
548
|
-
# Ignore self parameter
|
549
|
-
if "self" in inputs:
|
550
|
-
del inputs["self"]
|
551
|
-
span.inputs = inputs
|
552
|
-
|
553
|
-
# Queue span with input data
|
554
|
-
try:
|
555
|
-
if self.background_span_service:
|
556
|
-
self.background_span_service.queue_span(span, span_state="input")
|
557
|
-
except Exception as e:
|
558
|
-
judgeval_logger.warning(f"Failed to queue span with input data: {e}")
|
559
|
-
|
560
|
-
def record_agent_name(self, agent_name: str):
|
561
|
-
current_span_id = self.get_current_span()
|
562
|
-
if current_span_id:
|
563
|
-
span = self.span_id_to_span[current_span_id]
|
564
|
-
span.agent_name = agent_name
|
565
|
-
|
566
|
-
# Queue span with agent_name data
|
567
|
-
if self.background_span_service:
|
568
|
-
self.background_span_service.queue_span(span, span_state="agent_name")
|
569
|
-
|
570
|
-
def record_state_before(self, state: dict):
|
571
|
-
"""Records the agent's state before a tool execution on the current span.
|
572
|
-
|
573
|
-
Args:
|
574
|
-
state: A dictionary representing the agent's state.
|
575
|
-
"""
|
576
|
-
current_span_id = self.get_current_span()
|
577
|
-
if current_span_id:
|
578
|
-
span = self.span_id_to_span[current_span_id]
|
579
|
-
span.state_before = state
|
580
|
-
|
581
|
-
# Queue span with state_before data
|
582
|
-
if self.background_span_service:
|
583
|
-
self.background_span_service.queue_span(span, span_state="state_before")
|
584
|
-
|
585
|
-
def record_state_after(self, state: dict):
|
586
|
-
"""Records the agent's state after a tool execution on the current span.
|
587
|
-
|
588
|
-
Args:
|
589
|
-
state: A dictionary representing the agent's state.
|
590
|
-
"""
|
591
|
-
current_span_id = self.get_current_span()
|
592
|
-
if current_span_id:
|
593
|
-
span = self.span_id_to_span[current_span_id]
|
594
|
-
span.state_after = state
|
595
|
-
|
596
|
-
# Queue span with state_after data
|
597
|
-
if self.background_span_service:
|
598
|
-
self.background_span_service.queue_span(span, span_state="state_after")
|
599
|
-
|
600
|
-
async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
|
601
|
-
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
602
|
-
try:
|
603
|
-
result = await coroutine
|
604
|
-
setattr(span, field, result)
|
605
|
-
|
606
|
-
# Queue span with output data now that coroutine is complete
|
607
|
-
if self.background_span_service and field == "output":
|
608
|
-
self.background_span_service.queue_span(span, span_state="output")
|
609
|
-
|
610
|
-
return result
|
611
|
-
except Exception as e:
|
612
|
-
setattr(span, field, f"Error: {str(e)}")
|
613
|
-
|
614
|
-
# Queue span even if there was an error
|
615
|
-
if self.background_span_service and field == "output":
|
616
|
-
self.background_span_service.queue_span(span, span_state="output")
|
617
|
-
|
618
|
-
raise
|
619
|
-
|
620
|
-
def record_output(self, output: Any):
|
621
|
-
current_span_id = self.get_current_span()
|
622
|
-
if current_span_id:
|
623
|
-
span = self.span_id_to_span[current_span_id]
|
624
|
-
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
625
|
-
|
626
|
-
if inspect.iscoroutine(output):
|
627
|
-
asyncio.create_task(self._update_coroutine(span, output, "output"))
|
628
|
-
|
629
|
-
# # Queue span with output data (unless it's pending)
|
630
|
-
if self.background_span_service and not inspect.iscoroutine(output):
|
631
|
-
self.background_span_service.queue_span(span, span_state="output")
|
632
|
-
|
633
|
-
return span # Return the created entry
|
634
|
-
# Removed else block - original didn't have one
|
635
|
-
return None # Return None if no span_id found
|
636
|
-
|
637
|
-
def record_usage(self, usage: TraceUsage):
|
638
|
-
current_span_id = self.get_current_span()
|
639
|
-
if current_span_id:
|
640
|
-
span = self.span_id_to_span[current_span_id]
|
641
|
-
span.usage = usage
|
642
|
-
|
643
|
-
# Queue span with usage data
|
644
|
-
if self.background_span_service:
|
645
|
-
self.background_span_service.queue_span(span, span_state="usage")
|
646
|
-
|
647
|
-
return span # Return the created entry
|
648
|
-
# Removed else block - original didn't have one
|
649
|
-
return None # Return None if no span_id found
|
650
|
-
|
651
|
-
def record_error(self, error: Dict[str, Any]):
|
652
|
-
current_span_id = self.get_current_span()
|
653
|
-
if current_span_id:
|
654
|
-
span = self.span_id_to_span[current_span_id]
|
655
|
-
span.error = error
|
656
|
-
|
657
|
-
# Queue span with error data
|
658
|
-
if self.background_span_service:
|
659
|
-
self.background_span_service.queue_span(span, span_state="error")
|
660
|
-
|
661
|
-
return span
|
662
|
-
return None
|
663
|
-
|
664
|
-
def add_span(self, span: TraceSpan):
|
665
|
-
"""Add a trace span to this trace context"""
|
666
|
-
self.trace_spans.append(span)
|
667
|
-
self.span_id_to_span[span.span_id] = span
|
668
|
-
return self
|
669
|
-
|
670
|
-
def print(self):
|
671
|
-
"""Print the complete trace with proper visual structure"""
|
672
|
-
for span in self.trace_spans:
|
673
|
-
span.print_span()
|
674
|
-
|
675
|
-
def get_duration(self) -> float:
|
676
|
-
"""
|
677
|
-
Get the total duration of this trace
|
678
|
-
"""
|
679
|
-
if self.start_time is None:
|
680
|
-
return 0.0 # No duration if trace hasn't been saved yet
|
681
|
-
return time.time() - self.start_time
|
682
|
-
|
683
|
-
def save(self, final_save: bool = False) -> Tuple[str, dict]:
|
684
|
-
"""
|
685
|
-
Save the current trace to the database with rate limiting checks.
|
686
|
-
First checks usage limits, then upserts the trace if allowed.
|
687
|
-
|
688
|
-
Args:
|
689
|
-
final_save: Whether this is the final save (updates usage counters)
|
690
|
-
|
691
|
-
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
692
|
-
"""
|
693
|
-
|
694
|
-
# Calculate total elapsed time
|
695
|
-
total_duration = self.get_duration()
|
696
|
-
|
697
|
-
# Create trace document
|
698
|
-
trace_data = {
|
699
|
-
"trace_id": self.trace_id,
|
700
|
-
"name": self.name,
|
701
|
-
"project_name": self.project_name,
|
702
|
-
"created_at": datetime.fromtimestamp(
|
703
|
-
self.start_time or time.time(), timezone.utc
|
704
|
-
).isoformat(),
|
705
|
-
"duration": total_duration,
|
706
|
-
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
707
|
-
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
708
|
-
"offline_mode": self.tracer.offline_mode,
|
709
|
-
"parent_trace_id": self.parent_trace_id,
|
710
|
-
"parent_name": self.parent_name,
|
711
|
-
"customer_id": self.customer_id,
|
712
|
-
"tags": self.tags,
|
713
|
-
"metadata": self.metadata,
|
714
|
-
"update_id": self.update_id,
|
715
|
-
}
|
716
|
-
|
717
|
-
# If usage check passes, upsert the trace
|
718
|
-
server_response = self.trace_manager_client.upsert_trace(
|
719
|
-
trace_data,
|
720
|
-
offline_mode=self.tracer.offline_mode,
|
721
|
-
show_link=not final_save, # Show link only on initial save, not final save
|
722
|
-
final_save=final_save, # Pass final_save to control S3 saving
|
723
|
-
)
|
724
|
-
|
725
|
-
if self.start_time is None:
|
726
|
-
self.start_time = time.time()
|
727
|
-
|
728
|
-
self.update_id += 1
|
729
|
-
|
730
|
-
# Upload annotations
|
731
|
-
# TODO: batch to the log endpoint
|
732
|
-
for annotation in self.annotations:
|
733
|
-
self.trace_manager_client.save_annotation(annotation)
|
734
|
-
return self.trace_id, server_response
|
735
|
-
|
736
|
-
def delete(self):
|
737
|
-
return self.trace_manager_client.delete_trace(self.trace_id)
|
738
|
-
|
739
|
-
def update_metadata(self, metadata: dict):
|
740
|
-
"""
|
741
|
-
Set metadata for this trace.
|
742
|
-
|
743
|
-
Args:
|
744
|
-
metadata: Metadata as a dictionary
|
745
|
-
|
746
|
-
Supported keys:
|
747
|
-
- customer_id: ID of the customer using this trace
|
748
|
-
- tags: List of tags for this trace
|
749
|
-
- has_notification: Whether this trace has a notification
|
750
|
-
- name: Name of the trace
|
751
|
-
"""
|
752
|
-
for k, v in metadata.items():
|
753
|
-
if k == "customer_id":
|
754
|
-
if v is not None:
|
755
|
-
self.customer_id = str(v)
|
756
|
-
else:
|
757
|
-
self.customer_id = None
|
758
|
-
elif k == "tags":
|
759
|
-
if isinstance(v, list):
|
760
|
-
# Validate that all items in the list are of the expected types
|
761
|
-
for item in v:
|
762
|
-
if not isinstance(item, (str, set, tuple)):
|
763
|
-
raise ValueError(
|
764
|
-
f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
|
765
|
-
)
|
766
|
-
self.tags = v
|
767
|
-
else:
|
768
|
-
raise ValueError(
|
769
|
-
f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
|
770
|
-
)
|
771
|
-
elif k == "has_notification":
|
772
|
-
if not isinstance(v, bool):
|
773
|
-
raise ValueError(
|
774
|
-
f"has_notification must be a boolean, got {type(v)}"
|
775
|
-
)
|
776
|
-
self.has_notification = v
|
777
|
-
elif k == "name":
|
778
|
-
self.name = v
|
779
|
-
else:
|
780
|
-
self.metadata[k] = v
|
781
|
-
|
782
|
-
def set_customer_id(self, customer_id: str):
|
783
|
-
"""
|
784
|
-
Set the customer ID for this trace.
|
785
|
-
|
786
|
-
Args:
|
787
|
-
customer_id: The customer ID to set
|
788
|
-
"""
|
789
|
-
self.update_metadata({"customer_id": customer_id})
|
790
|
-
|
791
|
-
def set_tags(self, tags: List[Union[str, set, tuple]]):
|
792
|
-
"""
|
793
|
-
Set the tags for this trace.
|
794
|
-
|
795
|
-
Args:
|
796
|
-
tags: List of tags to set
|
797
|
-
"""
|
798
|
-
self.update_metadata({"tags": tags})
|
799
|
-
|
800
|
-
|
801
|
-
def _capture_exception_for_trace(
|
802
|
-
current_trace: Optional["TraceClient"], exc_info: ExcInfo
|
803
|
-
):
|
804
|
-
if not current_trace:
|
805
|
-
return
|
806
|
-
|
807
|
-
exc_type, exc_value, exc_traceback_obj = exc_info
|
808
|
-
formatted_exception = {
|
809
|
-
"type": exc_type.__name__ if exc_type else "UnknownExceptionType",
|
810
|
-
"message": str(exc_value) if exc_value else "No exception message",
|
811
|
-
"traceback": traceback.format_tb(exc_traceback_obj)
|
812
|
-
if exc_traceback_obj
|
813
|
-
else [],
|
814
|
-
}
|
815
|
-
|
816
|
-
# This is where we specially handle exceptions that we might want to collect additional data for.
|
817
|
-
# When we do this, always try checking the module from sys.modules instead of importing. This will
|
818
|
-
# Let us support a wider range of exceptions without needing to import them for all clients.
|
819
|
-
|
820
|
-
# Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
|
821
|
-
# The alternative is to hand select libraries we want from sys.modules and check for them:
|
822
|
-
# As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
|
823
|
-
|
824
|
-
# General HTTP Like errors
|
825
|
-
try:
|
826
|
-
url = getattr(getattr(exc_value, "request", None), "url", None)
|
827
|
-
status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
|
828
|
-
if status_code:
|
829
|
-
formatted_exception["http"] = {
|
830
|
-
"url": url if url else "Unknown URL",
|
831
|
-
"status_code": status_code if status_code else None,
|
832
|
-
}
|
833
|
-
except Exception:
|
834
|
-
pass
|
835
|
-
|
836
|
-
current_trace.record_error(formatted_exception)
|
837
|
-
|
838
|
-
# Queue the span with error state through background service
|
839
|
-
if current_trace.background_span_service:
|
840
|
-
current_span_id = current_trace.get_current_span()
|
841
|
-
if current_span_id and current_span_id in current_trace.span_id_to_span:
|
842
|
-
error_span = current_trace.span_id_to_span[current_span_id]
|
843
|
-
current_trace.background_span_service.queue_span(
|
844
|
-
error_span, span_state="error"
|
845
|
-
)
|
846
|
-
|
847
|
-
|
848
|
-
class BackgroundSpanService:
|
849
|
-
"""
|
850
|
-
Background service for queueing and batching trace spans for efficient saving.
|
851
|
-
|
852
|
-
This service:
|
853
|
-
- Queues spans as they complete
|
854
|
-
- Batches them for efficient network usage
|
855
|
-
- Sends spans periodically or when batches reach a certain size
|
856
|
-
- Handles automatic flushing when the main event terminates
|
857
|
-
"""
|
858
|
-
|
859
|
-
def __init__(
|
860
|
-
self,
|
861
|
-
judgment_api_key: str,
|
862
|
-
organization_id: str,
|
863
|
-
batch_size: int = 10,
|
864
|
-
flush_interval: float = 5.0,
|
865
|
-
num_workers: int = 1,
|
866
|
-
):
|
867
|
-
"""
|
868
|
-
Initialize the background span service.
|
869
|
-
|
870
|
-
Args:
|
871
|
-
judgment_api_key: API key for Judgment service
|
872
|
-
organization_id: Organization ID
|
873
|
-
batch_size: Number of spans to batch before sending (default: 10)
|
874
|
-
flush_interval: Time in seconds between automatic flushes (default: 5.0)
|
875
|
-
num_workers: Number of worker threads to process the queue (default: 1)
|
876
|
-
"""
|
877
|
-
self.judgment_api_key = judgment_api_key
|
878
|
-
self.organization_id = organization_id
|
879
|
-
self.batch_size = batch_size
|
880
|
-
self.flush_interval = flush_interval
|
881
|
-
self.num_workers = max(1, num_workers) # Ensure at least 1 worker
|
882
|
-
|
883
|
-
# Queue for pending spans
|
884
|
-
self._span_queue: queue.Queue[Dict[str, Any]] = queue.Queue()
|
885
|
-
|
886
|
-
# Background threads for processing spans
|
887
|
-
self._worker_threads: List[threading.Thread] = []
|
888
|
-
self._shutdown_event = threading.Event()
|
889
|
-
|
890
|
-
# Track spans that have been sent
|
891
|
-
# self._sent_spans = set()
|
892
|
-
|
893
|
-
# Register cleanup on exit
|
894
|
-
atexit.register(self.shutdown)
|
895
|
-
|
896
|
-
# Start the background workers
|
897
|
-
self._start_workers()
|
898
|
-
|
899
|
-
def _start_workers(self):
|
900
|
-
"""Start the background worker threads."""
|
901
|
-
for i in range(self.num_workers):
|
902
|
-
if len(self._worker_threads) < self.num_workers:
|
903
|
-
worker_thread = threading.Thread(
|
904
|
-
target=self._worker_loop, daemon=True, name=f"SpanWorker-{i + 1}"
|
905
|
-
)
|
906
|
-
worker_thread.start()
|
907
|
-
self._worker_threads.append(worker_thread)
|
908
|
-
|
909
|
-
def _worker_loop(self):
|
910
|
-
"""Main worker loop that processes spans in batches."""
|
911
|
-
batch = []
|
912
|
-
last_flush_time = time.time()
|
913
|
-
pending_task_count = (
|
914
|
-
0 # Track how many tasks we've taken from queue but not marked done
|
915
|
-
)
|
916
|
-
|
917
|
-
while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
|
918
|
-
try:
|
919
|
-
# First, do a blocking get to wait for at least one item
|
920
|
-
if not batch: # Only block if we don't have items already
|
921
|
-
try:
|
922
|
-
span_data = self._span_queue.get(timeout=1.0)
|
923
|
-
batch.append(span_data)
|
924
|
-
pending_task_count += 1
|
925
|
-
except queue.Empty:
|
926
|
-
# No new spans, continue to check for flush conditions
|
927
|
-
pass
|
928
|
-
|
929
|
-
# Then, do non-blocking gets to drain any additional available items
|
930
|
-
# up to our batch size limit
|
931
|
-
while len(batch) < self.batch_size:
|
932
|
-
try:
|
933
|
-
span_data = self._span_queue.get_nowait() # Non-blocking
|
934
|
-
batch.append(span_data)
|
935
|
-
pending_task_count += 1
|
936
|
-
except queue.Empty:
|
937
|
-
# No more items immediately available
|
938
|
-
break
|
939
|
-
|
940
|
-
current_time = time.time()
|
941
|
-
should_flush = len(batch) >= self.batch_size or (
|
942
|
-
batch and (current_time - last_flush_time) >= self.flush_interval
|
943
|
-
)
|
944
|
-
|
945
|
-
if should_flush and batch:
|
946
|
-
self._send_batch(batch)
|
947
|
-
|
948
|
-
# Only mark tasks as done after successful sending
|
949
|
-
for _ in range(pending_task_count):
|
950
|
-
self._span_queue.task_done()
|
951
|
-
pending_task_count = 0 # Reset counter
|
952
|
-
|
953
|
-
batch.clear()
|
954
|
-
last_flush_time = current_time
|
955
|
-
|
956
|
-
except Exception as e:
|
957
|
-
judgeval_logger.warning(f"Error in span service worker loop: {e}")
|
958
|
-
# On error, still need to mark tasks as done to prevent hanging
|
959
|
-
for _ in range(pending_task_count):
|
960
|
-
self._span_queue.task_done()
|
961
|
-
pending_task_count = 0
|
962
|
-
batch.clear()
|
963
|
-
|
964
|
-
# Final flush on shutdown
|
965
|
-
if batch:
|
966
|
-
self._send_batch(batch)
|
967
|
-
# Mark remaining tasks as done
|
968
|
-
for _ in range(pending_task_count):
|
969
|
-
self._span_queue.task_done()
|
970
|
-
|
971
|
-
def _send_batch(self, batch: List[Dict[str, Any]]):
|
972
|
-
"""
|
973
|
-
Send a batch of spans to the server.
|
974
|
-
|
975
|
-
Args:
|
976
|
-
batch: List of span dictionaries to send
|
977
|
-
"""
|
978
|
-
if not batch:
|
979
|
-
return
|
980
|
-
|
981
|
-
try:
|
982
|
-
# Group items by type for different endpoints
|
983
|
-
spans_to_send = []
|
984
|
-
evaluation_runs_to_send = []
|
985
|
-
|
986
|
-
for item in batch:
|
987
|
-
if item["type"] == "span":
|
988
|
-
spans_to_send.append(item["data"])
|
989
|
-
elif item["type"] == "evaluation_run":
|
990
|
-
evaluation_runs_to_send.append(item["data"])
|
991
|
-
|
992
|
-
# Send spans if any
|
993
|
-
if spans_to_send:
|
994
|
-
self._send_spans_batch(spans_to_send)
|
995
|
-
|
996
|
-
# Send evaluation runs if any
|
997
|
-
if evaluation_runs_to_send:
|
998
|
-
self._send_evaluation_runs_batch(evaluation_runs_to_send)
|
999
|
-
|
1000
|
-
except Exception as e:
|
1001
|
-
judgeval_logger.warning(f"Failed to send batch: {e}")
|
1002
|
-
|
1003
|
-
def _send_spans_batch(self, spans: List[Dict[str, Any]]):
|
1004
|
-
"""Send a batch of spans to the spans endpoint."""
|
1005
|
-
payload = {"spans": spans, "organization_id": self.organization_id}
|
1006
|
-
|
1007
|
-
# Serialize with fallback encoder
|
1008
|
-
def fallback_encoder(obj):
|
1009
|
-
try:
|
1010
|
-
return repr(obj)
|
1011
|
-
except Exception:
|
1012
|
-
try:
|
1013
|
-
return str(obj)
|
1014
|
-
except Exception as e:
|
1015
|
-
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
1016
|
-
|
1017
|
-
try:
|
1018
|
-
serialized_data = json.dumps(payload, default=fallback_encoder)
|
1019
|
-
|
1020
|
-
# Send the actual HTTP request to the batch endpoint
|
1021
|
-
response = requests.post(
|
1022
|
-
JUDGMENT_TRACES_SPANS_BATCH_API_URL,
|
1023
|
-
data=serialized_data,
|
1024
|
-
headers={
|
1025
|
-
"Content-Type": "application/json",
|
1026
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
1027
|
-
"X-Organization-Id": self.organization_id,
|
1028
|
-
},
|
1029
|
-
verify=True,
|
1030
|
-
timeout=30, # Add timeout to prevent hanging
|
1031
|
-
)
|
1032
|
-
|
1033
|
-
if response.status_code != HTTPStatus.OK:
|
1034
|
-
judgeval_logger.warning(
|
1035
|
-
f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
|
1036
|
-
)
|
1037
|
-
|
1038
|
-
except RequestException as e:
|
1039
|
-
judgeval_logger.warning(f"Network error sending spans batch: {e}")
|
1040
|
-
except Exception as e:
|
1041
|
-
judgeval_logger.warning(f"Failed to serialize or send spans batch: {e}")
|
1042
|
-
|
1043
|
-
def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
|
1044
|
-
"""Send a batch of evaluation runs with their associated span data to the endpoint."""
|
1045
|
-
# Structure payload to include both evaluation run data and span data
|
1046
|
-
evaluation_entries = []
|
1047
|
-
for eval_data in evaluation_runs:
|
1048
|
-
# eval_data already contains the evaluation run data (no need to access ['data'])
|
1049
|
-
entry = {
|
1050
|
-
"evaluation_run": {
|
1051
|
-
# Extract evaluation run fields (excluding span-specific fields)
|
1052
|
-
key: value
|
1053
|
-
for key, value in eval_data.items()
|
1054
|
-
if key not in ["associated_span_id", "span_data", "queued_at"]
|
1055
|
-
},
|
1056
|
-
"associated_span": {
|
1057
|
-
"span_id": eval_data.get("associated_span_id"),
|
1058
|
-
"span_data": eval_data.get("span_data"),
|
1059
|
-
},
|
1060
|
-
"queued_at": eval_data.get("queued_at"),
|
1061
|
-
}
|
1062
|
-
evaluation_entries.append(entry)
|
1063
|
-
|
1064
|
-
payload = {
|
1065
|
-
"organization_id": self.organization_id,
|
1066
|
-
"evaluation_entries": evaluation_entries, # Each entry contains both eval run + span data
|
1067
|
-
}
|
1068
|
-
|
1069
|
-
# Serialize with fallback encoder
|
1070
|
-
def fallback_encoder(obj):
|
1071
|
-
try:
|
1072
|
-
return repr(obj)
|
1073
|
-
except Exception:
|
1074
|
-
try:
|
1075
|
-
return str(obj)
|
1076
|
-
except Exception as e:
|
1077
|
-
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
1078
|
-
|
1079
|
-
try:
|
1080
|
-
serialized_data = json.dumps(payload, default=fallback_encoder)
|
1081
|
-
|
1082
|
-
# Send the actual HTTP request to the batch endpoint
|
1083
|
-
response = requests.post(
|
1084
|
-
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
|
1085
|
-
data=serialized_data,
|
1086
|
-
headers={
|
1087
|
-
"Content-Type": "application/json",
|
1088
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
1089
|
-
"X-Organization-Id": self.organization_id,
|
1090
|
-
},
|
1091
|
-
verify=True,
|
1092
|
-
timeout=30, # Add timeout to prevent hanging
|
1093
|
-
)
|
1094
|
-
|
1095
|
-
if response.status_code != HTTPStatus.OK:
|
1096
|
-
judgeval_logger.warning(
|
1097
|
-
f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
|
1098
|
-
)
|
1099
|
-
|
1100
|
-
except RequestException as e:
|
1101
|
-
judgeval_logger.warning(f"Network error sending evaluation runs batch: {e}")
|
1102
|
-
except Exception as e:
|
1103
|
-
judgeval_logger.warning(f"Failed to send evaluation runs batch: {e}")
|
1104
|
-
|
1105
|
-
def queue_span(self, span: TraceSpan, span_state: str = "input"):
|
1106
|
-
"""
|
1107
|
-
Queue a span for background sending.
|
1108
|
-
|
1109
|
-
Args:
|
1110
|
-
span: The TraceSpan object to queue
|
1111
|
-
span_state: State of the span ("input", "output", "completed")
|
1112
|
-
"""
|
1113
|
-
if not self._shutdown_event.is_set():
|
1114
|
-
# Increment update_id when queueing the span
|
1115
|
-
span.increment_update_id()
|
1116
|
-
span_data = {
|
1117
|
-
"type": "span",
|
1118
|
-
"data": {
|
1119
|
-
**span.model_dump(),
|
1120
|
-
"span_state": span_state,
|
1121
|
-
"queued_at": time.time(),
|
1122
|
-
},
|
1123
|
-
}
|
1124
|
-
self._span_queue.put(span_data)
|
1125
|
-
|
1126
|
-
def queue_evaluation_run(
|
1127
|
-
self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan
|
1128
|
-
):
|
1129
|
-
"""
|
1130
|
-
Queue an evaluation run for background sending.
|
1131
|
-
|
1132
|
-
Args:
|
1133
|
-
evaluation_run: The EvaluationRun object to queue
|
1134
|
-
span_id: The span ID associated with this evaluation run
|
1135
|
-
span_data: The span data at the time of evaluation (to avoid race conditions)
|
1136
|
-
"""
|
1137
|
-
if not self._shutdown_event.is_set():
|
1138
|
-
eval_data = {
|
1139
|
-
"type": "evaluation_run",
|
1140
|
-
"data": {
|
1141
|
-
**evaluation_run.model_dump(),
|
1142
|
-
"associated_span_id": span_id,
|
1143
|
-
"span_data": span_data.model_dump(), # Include span data to avoid race conditions
|
1144
|
-
"queued_at": time.time(),
|
1145
|
-
},
|
1146
|
-
}
|
1147
|
-
self._span_queue.put(eval_data)
|
1148
|
-
|
1149
|
-
def flush(self):
|
1150
|
-
"""Force immediate sending of all queued spans."""
|
1151
|
-
try:
|
1152
|
-
# Wait for the queue to be processed
|
1153
|
-
self._span_queue.join()
|
1154
|
-
except Exception as e:
|
1155
|
-
judgeval_logger.warning(f"Error during flush: {e}")
|
1156
|
-
|
1157
|
-
def shutdown(self):
|
1158
|
-
"""Shutdown the background service and flush remaining spans."""
|
1159
|
-
if self._shutdown_event.is_set():
|
1160
|
-
return
|
1161
|
-
|
1162
|
-
try:
|
1163
|
-
# Signal shutdown to stop new items from being queued
|
1164
|
-
self._shutdown_event.set()
|
1165
|
-
|
1166
|
-
# Try to flush any remaining spans
|
1167
|
-
try:
|
1168
|
-
self.flush()
|
1169
|
-
except Exception as e:
|
1170
|
-
judgeval_logger.warning(f"Error during final flush: {e}")
|
1171
|
-
except Exception as e:
|
1172
|
-
judgeval_logger.warning(f"Error during BackgroundSpanService shutdown: {e}")
|
1173
|
-
finally:
|
1174
|
-
# Clear the worker threads list (daemon threads will be killed automatically)
|
1175
|
-
self._worker_threads.clear()
|
1176
|
-
|
1177
|
-
def get_queue_size(self) -> int:
|
1178
|
-
"""Get the current size of the span queue."""
|
1179
|
-
return self._span_queue.qsize()
|
1180
|
-
|
1181
|
-
|
1182
|
-
class _DeepTracer:
|
1183
|
-
_instance: Optional["_DeepTracer"] = None
|
1184
|
-
_lock: threading.Lock = threading.Lock()
|
1185
|
-
_refcount: int = 0
|
1186
|
-
_span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
|
1187
|
-
"_deep_profiler_span_stack", default=[]
|
1188
|
-
)
|
1189
|
-
_skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
|
1190
|
-
"_deep_profiler_skip_stack", default=[]
|
1191
|
-
)
|
1192
|
-
_original_sys_trace: Optional[Callable] = None
|
1193
|
-
_original_threading_trace: Optional[Callable] = None
|
1194
|
-
|
1195
|
-
def __init__(self, tracer: "Tracer"):
|
1196
|
-
self._tracer = tracer
|
1197
|
-
|
1198
|
-
def _get_qual_name(self, frame) -> str:
|
1199
|
-
func_name = frame.f_code.co_name
|
1200
|
-
module_name = frame.f_globals.get("__name__", "unknown_module")
|
1201
|
-
|
1202
|
-
try:
|
1203
|
-
func = frame.f_globals.get(func_name)
|
1204
|
-
if func is None:
|
1205
|
-
return f"{module_name}.{func_name}"
|
1206
|
-
if hasattr(func, "__qualname__"):
|
1207
|
-
return f"{module_name}.{func.__qualname__}"
|
1208
|
-
return f"{module_name}.{func_name}"
|
1209
|
-
except Exception:
|
1210
|
-
return f"{module_name}.{func_name}"
|
1211
|
-
|
1212
|
-
def __new__(cls, tracer: "Tracer"):
|
1213
|
-
with cls._lock:
|
1214
|
-
if cls._instance is None:
|
1215
|
-
cls._instance = super().__new__(cls)
|
1216
|
-
return cls._instance
|
1217
|
-
|
1218
|
-
def _should_trace(self, frame):
|
1219
|
-
# Skip stack is maintained by the tracer as an optimization to skip earlier
|
1220
|
-
# frames in the call stack that we've already determined should be skipped
|
1221
|
-
skip_stack = self._skip_stack.get()
|
1222
|
-
if len(skip_stack) > 0:
|
1223
|
-
return False
|
1224
|
-
|
1225
|
-
func_name = frame.f_code.co_name
|
1226
|
-
module_name = frame.f_globals.get("__name__", None)
|
1227
|
-
func = frame.f_globals.get(func_name)
|
1228
|
-
if func and (
|
1229
|
-
hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
|
1230
|
-
):
|
1231
|
-
return False
|
1232
|
-
|
1233
|
-
if (
|
1234
|
-
not module_name
|
1235
|
-
or func_name.startswith("<") # ex: <listcomp>
|
1236
|
-
or func_name.startswith("__")
|
1237
|
-
and func_name != "__call__" # dunders
|
1238
|
-
or not self._is_user_code(frame.f_code.co_filename)
|
1239
|
-
):
|
1240
|
-
return False
|
1241
|
-
|
1242
|
-
return True
|
1243
|
-
|
1244
|
-
@functools.cache
|
1245
|
-
def _is_user_code(self, filename: str):
|
1246
|
-
return (
|
1247
|
-
bool(filename)
|
1248
|
-
and not filename.startswith("<")
|
1249
|
-
and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
|
1250
|
-
)
|
1251
|
-
|
1252
|
-
def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
|
1253
|
-
"""Cooperative trace function for sys.settrace that chains with existing tracers."""
|
1254
|
-
# First, call the original sys trace function if it exists
|
1255
|
-
original_result = None
|
1256
|
-
if self._original_sys_trace:
|
1257
|
-
try:
|
1258
|
-
original_result = self._original_sys_trace(frame, event, arg)
|
1259
|
-
except Exception:
|
1260
|
-
# If the original tracer fails, continue with our tracing
|
1261
|
-
pass
|
1262
|
-
|
1263
|
-
# Then do our own tracing
|
1264
|
-
our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
|
1265
|
-
|
1266
|
-
# Return our tracer to continue tracing, but respect the original's decision
|
1267
|
-
# If the original tracer returned None (stop tracing), we should respect that
|
1268
|
-
if original_result is None and self._original_sys_trace:
|
1269
|
-
return None
|
1270
|
-
|
1271
|
-
return our_result or original_result
|
1272
|
-
|
1273
|
-
def _cooperative_threading_trace(
|
1274
|
-
self, frame: types.FrameType, event: str, arg: Any
|
1275
|
-
):
|
1276
|
-
"""Cooperative trace function for threading.settrace that chains with existing tracers."""
|
1277
|
-
# First, call the original threading trace function if it exists
|
1278
|
-
original_result = None
|
1279
|
-
if self._original_threading_trace:
|
1280
|
-
try:
|
1281
|
-
original_result = self._original_threading_trace(frame, event, arg)
|
1282
|
-
except Exception:
|
1283
|
-
# If the original tracer fails, continue with our tracing
|
1284
|
-
pass
|
1285
|
-
|
1286
|
-
# Then do our own tracing
|
1287
|
-
our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
|
1288
|
-
|
1289
|
-
# Return our tracer to continue tracing, but respect the original's decision
|
1290
|
-
# If the original tracer returned None (stop tracing), we should respect that
|
1291
|
-
if original_result is None and self._original_threading_trace:
|
1292
|
-
return None
|
1293
|
-
|
1294
|
-
return our_result or original_result
|
1295
|
-
|
1296
|
-
def _trace(
|
1297
|
-
self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
|
1298
|
-
):
|
1299
|
-
frame.f_trace_lines = False
|
1300
|
-
frame.f_trace_opcodes = False
|
1301
|
-
|
1302
|
-
if not self._should_trace(frame):
|
1303
|
-
return
|
1304
|
-
|
1305
|
-
if event not in ("call", "return", "exception"):
|
1306
|
-
return
|
1307
|
-
|
1308
|
-
current_trace = self._tracer.get_current_trace()
|
1309
|
-
if not current_trace:
|
1310
|
-
return
|
1311
|
-
|
1312
|
-
parent_span_id = self._tracer.get_current_span()
|
1313
|
-
if not parent_span_id:
|
1314
|
-
return
|
1315
|
-
|
1316
|
-
qual_name = self._get_qual_name(frame)
|
1317
|
-
instance_name = None
|
1318
|
-
if "self" in frame.f_locals:
|
1319
|
-
instance = frame.f_locals["self"]
|
1320
|
-
class_name = instance.__class__.__name__
|
1321
|
-
class_identifiers = getattr(self._tracer, "class_identifiers", {})
|
1322
|
-
instance_name = get_instance_prefixed_name(
|
1323
|
-
instance, class_name, class_identifiers
|
1324
|
-
)
|
1325
|
-
skip_stack = self._skip_stack.get()
|
1326
|
-
|
1327
|
-
if event == "call":
|
1328
|
-
# If we have entries in the skip stack and the current qual_name matches the top entry,
|
1329
|
-
            # push it again to track nesting depth and skip
            # As an optimization, we only care about duplicate qual_names.
            if skip_stack:
                if qual_name == skip_stack[-1]:
                    skip_stack.append(qual_name)
                    self._skip_stack.set(skip_stack)
                return

            should_trace = self._should_trace(frame)

            if not should_trace:
                if not skip_stack:
                    self._skip_stack.set([qual_name])
                return
        elif event == "return":
            # If we have entries in skip stack and current qual_name matches the top entry,
            # pop it to track exiting from the skipped section
            if skip_stack and qual_name == skip_stack[-1]:
                skip_stack.pop()
                self._skip_stack.set(skip_stack)
                return

            if skip_stack:
                return

        span_stack = self._span_stack.get()
        if event == "call":
            if not self._should_trace(frame):
                return

            span_id = str(uuid.uuid4())

            parent_depth = current_trace._span_depths.get(parent_span_id, 0)
            depth = parent_depth + 1

            current_trace._span_depths[span_id] = depth

            start_time = time.time()

            span_stack.append(
                {
                    "span_id": span_id,
                    "parent_span_id": parent_span_id,
                    "function": qual_name,
                    "start_time": start_time,
                }
            )
            self._span_stack.set(span_stack)

            token = self._tracer.set_current_span(span_id)
            frame.f_locals["_judgment_span_token"] = token

            span = TraceSpan(
                span_id=span_id,
                trace_id=current_trace.trace_id,
                depth=depth,
                message=qual_name,
                created_at=start_time,
                span_type="span",
                parent_span_id=parent_span_id,
                function=qual_name,
                agent_name=instance_name,
            )
            current_trace.add_span(span)

            inputs = {}
            try:
                args_info = inspect.getargvalues(frame)
                for arg in args_info.args:
                    try:
                        inputs[arg] = args_info.locals.get(arg)
                    except Exception:
                        inputs[arg] = "<<Unserializable>>"
                current_trace.record_input(inputs)
            except Exception as e:
                current_trace.record_input({"error": str(e)})

        elif event == "return":
            if not span_stack:
                return

            current_id = self._tracer.get_current_span()

            span_data = None
            for i, entry in enumerate(reversed(span_stack)):
                if entry["span_id"] == current_id:
                    span_data = span_stack.pop(-(i + 1))
                    self._span_stack.set(span_stack)
                    break

            if not span_data:
                return

            start_time = span_data["start_time"]
            duration = time.time() - start_time

            current_trace.span_id_to_span[span_data["span_id"]].duration = duration

            if arg is not None:
                # exception handling will take priority.
                current_trace.record_output(arg)

            if span_data["span_id"] in current_trace._span_depths:
                del current_trace._span_depths[span_data["span_id"]]

            if span_stack:
                self._tracer.set_current_span(span_stack[-1]["span_id"])
            else:
                self._tracer.set_current_span(span_data["parent_span_id"])

            if "_judgment_span_token" in frame.f_locals:
                self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])

        elif event == "exception":
            exc_type = arg[0]
            if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
                return
            _capture_exception_for_trace(current_trace, arg)

        return continuation_func

    def __enter__(self):
        with self._lock:
            self._refcount += 1
            if self._refcount == 1:
                # Store the existing trace functions before setting ours
                self._original_sys_trace = sys.gettrace()
                self._original_threading_trace = threading.gettrace()

                self._skip_stack.set([])
                self._span_stack.set([])

                sys.settrace(self._cooperative_sys_trace)
                threading.settrace(self._cooperative_threading_trace)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        with self._lock:
            self._refcount -= 1
            if self._refcount == 0:
                # Restore the original trace functions instead of setting to None
                sys.settrace(self._original_sys_trace)
                threading.settrace(self._original_threading_trace)

                # Clean up the references
                self._original_sys_trace = None
                self._original_threading_trace = None

    # Below commented out function isn't used anymore?

    # def log(self, message: str, level: str = "info"):
    #     """ Log a message with the span context """
    #     current_trace = self._tracer.get_current_trace()
    #     if current_trace:
    #         current_trace.log(message, level)
    #     else:
    #         print(f"[{level}] {message}")
    #     current_trace.record_output({"log": message})

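
# Minimal standalone sketch of the cooperative settrace pattern used by
# _DeepTracer above (illustrative only, not from the original file): capture
# any pre-existing trace function, chain to it, and restore it afterwards,
# mirroring __enter__/__exit__.
def _example_cooperative_settrace(traced_work) -> None:
    previous = sys.gettrace()

    def chained(frame, event, arg):
        if previous is not None:
            previous(frame, event, arg)  # let any prior tracer observe the event too
        return chained

    sys.settrace(chained)
    try:
        traced_work()  # traced work happens here
    finally:
        sys.settrace(previous)  # restore rather than clear, like __exit__ above
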
class Tracer:
    # Tracer.current_trace class variable is currently used in wrap()
    # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
    # Should be fine to do so as long as we keep Tracer as a singleton
    current_trace: Optional[TraceClient] = None
    # current_span_id: Optional[str] = None

    trace_across_async_contexts: bool = (
        False  # By default, we don't trace across async contexts
    )

    def __init__(
        self,
        api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
        organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
        project_name: str | None = None,
        deep_tracing: bool = False,  # Deep tracing is disabled by default
        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
        == "true",
        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
        == "true",
        # S3 configuration
        use_s3: bool = False,
        s3_bucket_name: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        trace_across_async_contexts: bool = False,  # By default, we don't trace across async contexts
        # Background span service configuration
        span_batch_size: int = 50,  # Number of spans to batch before sending
        span_flush_interval: float = 1.0,  # Time in seconds between automatic flushes
        span_num_workers: int = 10,  # Number of worker threads for span processing
    ):
        try:
            if not api_key:
                raise ValueError(
                    "api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
                )

            if not organization_id:
                raise ValueError(
                    "organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
                )

            try:
                result, response = validate_api_key(api_key)
            except Exception as e:
                judgeval_logger.error(
                    f"Issue with verifying API key, disabling monitoring: {e}"
                )
                enable_monitoring = False
                result = True

            if not result:
                raise ValueError(f"Issue with passed in Judgment API key: {response}")

            if use_s3 and not s3_bucket_name:
                raise ValueError("S3 bucket name must be provided when use_s3 is True")

            self.api_key: str = api_key
            self.project_name: str = project_name or "default_project"
            self.organization_id: str = organization_id
            self.traces: List[Trace] = []
            self.enable_monitoring: bool = enable_monitoring
            self.enable_evaluations: bool = enable_evaluations
            self.class_identifiers: Dict[
                str, str
            ] = {}  # Dictionary to store class identifiers
            self.span_id_to_previous_span_id: Dict[str, str | None] = {}
            self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
            self.current_span_id: Optional[str] = None
            self.current_trace: Optional[TraceClient] = None
            self.trace_across_async_contexts: bool = trace_across_async_contexts
            Tracer.trace_across_async_contexts = trace_across_async_contexts

            # Initialize S3 storage if enabled
            self.use_s3 = use_s3
            if use_s3:
                from judgeval.common.s3_storage import S3Storage

                try:
                    self.s3_storage = S3Storage(
                        bucket_name=s3_bucket_name,
                        aws_access_key_id=s3_aws_access_key_id,
                        aws_secret_access_key=s3_aws_secret_access_key,
                        region_name=s3_region_name,
                    )
                except Exception as e:
                    judgeval_logger.error(
                        f"Issue with initializing S3 storage, disabling S3: {e}"
                    )
                    self.use_s3 = False

            self.offline_mode = False  # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
            self.deep_tracing: bool = deep_tracing  # NEW: Store deep tracing setting

            # Initialize background span service
            self.background_span_service: Optional[BackgroundSpanService] = None
            self.background_span_service = BackgroundSpanService(
                judgment_api_key=api_key,
                organization_id=organization_id,
                batch_size=span_batch_size,
                flush_interval=span_flush_interval,
                num_workers=span_num_workers,
            )
        except Exception as e:
            judgeval_logger.error(
                f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
            )
            self.enable_monitoring = False
            self.enable_evaluations = False

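    # Usage sketch for the constructor above (illustrative; the key and org ID
    # normally come from the JUDGMENT_API_KEY / JUDGMENT_ORG_ID environment
    # variables, and the project name below is hypothetical):
    #
    #     tracer = Tracer(
    #         project_name="demo_project",
    #         deep_tracing=False,
    #         span_batch_size=50,
    #         span_flush_interval=1.0,
    #     )
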
    def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
        self.span_id_to_previous_span_id[span_id] = self.current_span_id
        self.current_span_id = span_id
        Tracer.current_span_id = span_id
        try:
            token = current_span_var.set(span_id)
        except Exception:
            token = None
        return token

    def get_current_span(self) -> Optional[str]:
        try:
            current_span_var_val = current_span_var.get()
        except Exception:
            current_span_var_val = None
        return (
            (self.current_span_id or current_span_var_val)
            if self.trace_across_async_contexts
            else current_span_var_val
        )

    def reset_current_span(
        self,
        token: Optional[contextvars.Token[str | None]] = None,
        span_id: Optional[str] = None,
    ):
        try:
            if token:
                current_span_var.reset(token)
        except Exception:
            pass
        if not span_id:
            span_id = self.current_span_id
        if span_id:
            self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
            Tracer.current_span_id = self.current_span_id

    def set_current_trace(
        self, trace: TraceClient
    ) -> Optional[contextvars.Token[TraceClient | None]]:
        """
        Set the current trace context in contextvars
        """
        self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
        self.current_trace = trace
        Tracer.current_trace = trace
        try:
            token = current_trace_var.set(trace)
        except Exception:
            token = None
        return token

    def get_current_trace(self) -> Optional[TraceClient]:
        """
        Get the current trace context.

        Tries to get the trace client from the context variable first.
        If not found (e.g., context lost across threads/tasks),
        it falls back to the active trace client managed by the callback handler.
        """
        try:
            current_trace_var_val = current_trace_var.get()
        except Exception:
            current_trace_var_val = None
        return (
            (self.current_trace or current_trace_var_val)
            if self.trace_across_async_contexts
            else current_trace_var_val
        )

    def reset_current_trace(
        self,
        token: Optional[contextvars.Token[TraceClient | None]] = None,
        trace_id: Optional[str] = None,
    ):
        try:
            if token:
                current_trace_var.reset(token)
        except Exception:
            pass
        if not trace_id and self.current_trace:
            trace_id = self.current_trace.trace_id
        if trace_id:
            self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
            Tracer.current_trace = self.current_trace

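    # Sketch of the token pattern the set_/reset_ methods above rely on
    # (illustrative): a contextvars token restores the exact prior value,
    # even under nested or concurrent use.
    #
    #     var: contextvars.ContextVar[str | None] = contextvars.ContextVar("demo", default=None)
    #     token = var.set("inner")
    #     assert var.get() == "inner"
    #     var.reset(token)          # back to the previous value, like reset_current_trace
    #     assert var.get() is None
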
    @contextmanager
    def trace(
        self, name: str, project_name: str | None = None
    ) -> Generator[TraceClient, None, None]:
        """Start a new trace context using a context manager"""
        trace_id = str(uuid.uuid4())
        project = project_name if project_name is not None else self.project_name

        # Get parent trace info from context
        parent_trace = self.get_current_trace()
        parent_trace_id = None
        parent_name = None

        if parent_trace:
            parent_trace_id = parent_trace.trace_id
            parent_name = parent_trace.name

        trace = TraceClient(
            self,
            trace_id,
            name,
            project_name=project,
            enable_monitoring=self.enable_monitoring,
            enable_evaluations=self.enable_evaluations,
            parent_trace_id=parent_trace_id,
            parent_name=parent_name,
        )

        # Set the current trace in context variables
        token = self.set_current_trace(trace)

        # Automatically create top-level span
        with trace.span(name or "unnamed_trace"):
            try:
                # Save the trace to the database to handle Evaluations' trace_id referential integrity
                yield trace
            finally:
                # Reset the context variable
                self.reset_current_trace(token)

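    # Usage sketch for trace() (span and project names are illustrative):
    #
    #     with tracer.trace("handle_request", project_name="demo_project") as trace:
    #         trace.record_input({"query": "hello"})
    #         trace.record_output({"answer": "world"})
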
    def log(self, msg: str, label: str = "log", score: int = 1):
        """Log a message with the current span context"""
        current_span_id = self.get_current_span()
        current_trace = self.get_current_trace()
        if current_span_id and current_trace:
            annotation = TraceAnnotation(
                span_id=current_span_id, text=msg, label=label, score=score
            )
            current_trace.add_annotation(annotation)

        rprint(f"[bold]{label}:[/bold] {msg}")

    def identify(
        self,
        identifier: str,
        track_state: bool = False,
        track_attributes: Optional[List[str]] = None,
        field_mappings: Optional[Dict[str, str]] = None,
    ):
        """
        Class decorator that associates a class with a custom identifier and enables state tracking.

        This decorator creates a mapping between the class name and the provided
        identifier, which can be useful for tagging, grouping, or referencing
        classes in a standardized way. It also enables automatic state capture
        for instances of the decorated class when used with tracing.

        Args:
            identifier: The identifier to associate with the decorated class.
                This will be used as the instance name in traces.
            track_state: Whether to automatically capture the state (attributes)
                of instances before and after function execution. Defaults to False.
            track_attributes: Optional list of specific attribute names to track.
                If None, all non-private attributes (not starting with '_')
                will be tracked when track_state=True.
            field_mappings: Optional dictionary mapping internal attribute names to
                display names in the captured state. For example:
                {"system_prompt": "instructions"} will capture the
                'instructions' attribute as 'system_prompt' in the state.

        Example:
            @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
            class User:
                # Class implementation
        """

        def decorator(cls):
            class_name = cls.__name__
            self.class_identifiers[class_name] = {
                "identifier": identifier,
                "track_state": track_state,
                "track_attributes": track_attributes,
                "field_mappings": field_mappings or {},
            }
            return cls

        return decorator

    def _capture_instance_state(
        self, instance: Any, class_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Capture the state of an instance based on class configuration.

        Args:
            instance: The instance to capture the state of.
            class_config: Configuration dictionary for state capture,
                expected to contain 'track_attributes' and 'field_mappings'.
        """
        track_attributes = class_config.get("track_attributes")
        field_mappings = class_config.get("field_mappings")

        if track_attributes:
            state = {attr: getattr(instance, attr, None) for attr in track_attributes}
        else:
            state = {
                k: v for k, v in instance.__dict__.items() if not k.startswith("_")
            }

        if field_mappings:
            state["field_mappings"] = field_mappings

        return state

    def _get_instance_state_if_tracked(self, args):
        """
        Extract instance state if the instance should be tracked.

        Returns the captured state dict if tracking is enabled, None otherwise.
        """
        if args and hasattr(args[0], "__class__"):
            instance = args[0]
            class_name = instance.__class__.__name__
            if (
                class_name in self.class_identifiers
                and isinstance(self.class_identifiers[class_name], dict)
                and self.class_identifiers[class_name].get("track_state", False)
            ):
                return self._capture_instance_state(
                    instance, self.class_identifiers[class_name]
                )

    def _conditionally_capture_and_record_state(
        self, trace_client_instance: TraceClient, args: tuple, is_before: bool
    ):
        """Captures instance state if tracked and records it via the trace_client."""
        state = self._get_instance_state_if_tracked(args)
        if state:
            if is_before:
                trace_client_instance.record_state_before(state)
            else:
                trace_client_instance.record_state_after(state)

    def observe(
        self,
        func=None,
        *,
        name=None,
        span_type: SpanType = "span",
    ):
        """
        Decorator to trace function execution with detailed entry/exit information.

        Args:
            func: The function to decorate
            name: Optional custom name for the span (defaults to function name)
            span_type: Type of span (default "span").
        """
        # If monitoring is disabled, return the function as is
        try:
            if not self.enable_monitoring:
                return func if func else lambda f: f

            if func is None:
                return lambda f: self.observe(
                    f,
                    name=name,
                    span_type=span_type,
                )

            # Use provided name or fall back to function name
            original_span_name = name or func.__name__

            # Store custom attributes on the function object
            func._judgment_span_name = original_span_name
            func._judgment_span_type = span_type

        except Exception:
            return func

        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                nonlocal original_span_name
                class_name = None
                span_name = original_span_name
                agent_name = None

                if args and hasattr(args[0], "__class__"):
                    class_name = args[0].__class__.__name__
                    agent_name = get_instance_prefixed_name(
                        args[0], class_name, self.class_identifiers
                    )

                # Get current trace from context
                current_trace = self.get_current_trace()

                # If there's no current trace, create a root trace
                if not current_trace:
                    trace_id = str(uuid.uuid4())
                    project = self.project_name

                    # Create a new trace client to serve as the root
                    current_trace = TraceClient(
                        self,
                        trace_id,
                        span_name,  # MODIFIED: Use span_name directly
                        project_name=project,
                        enable_monitoring=self.enable_monitoring,
                        enable_evaluations=self.enable_evaluations,
                    )

                    trace_token = self.set_current_trace(current_trace)

                    try:
                        # Use span for the function execution within the root trace
                        # This sets the current_span_var
                        with current_trace.span(
                            span_name, span_type=span_type
                        ) as span:  # MODIFIED: Use span_name directly
                            # Record inputs
                            inputs = combine_args_kwargs(func, args, kwargs)
                            span.record_input(inputs)
                            if agent_name:
                                span.record_agent_name(agent_name)

                            # Capture state before execution
                            self._conditionally_capture_and_record_state(
                                span, args, is_before=True
                            )

                            try:
                                if self.deep_tracing:
                                    with _DeepTracer(self):
                                        result = await func(*args, **kwargs)
                                else:
                                    result = await func(*args, **kwargs)
                            except Exception as e:
                                _capture_exception_for_trace(
                                    current_trace, sys.exc_info()
                                )
                                raise e

                            # Capture state after execution
                            self._conditionally_capture_and_record_state(
                                span, args, is_before=False
                            )

                            # Record output
                            span.record_output(result)
                            return result
                    finally:
                        # Flush background spans before saving the trace
                        try:
                            complete_trace_data = {
                                "trace_id": current_trace.trace_id,
                                "name": current_trace.name,
                                "created_at": datetime.fromtimestamp(
                                    current_trace.start_time or time.time(),
                                    timezone.utc,
                                ).isoformat(),
                                "duration": current_trace.get_duration(),
                                "trace_spans": [
                                    span.model_dump()
                                    for span in current_trace.trace_spans
                                ],
                                "offline_mode": self.offline_mode,
                                "parent_trace_id": current_trace.parent_trace_id,
                                "parent_name": current_trace.parent_name,
                            }
                            # Save the completed trace
                            trace_id, server_response = current_trace.save(
                                final_save=True
                            )

                            # Store the complete trace data instead of just server response
                            self.traces.append(complete_trace_data)

                            # if self.background_span_service:
                            #     self.background_span_service.flush()

                            # Reset trace context (span context resets automatically)
                            self.reset_current_trace(trace_token)
                        except Exception as e:
                            judgeval_logger.warning(f"Issue with async_wrapper: {e}")
                            return
                else:
                    with current_trace.span(span_name, span_type=span_type) as span:
                        inputs = combine_args_kwargs(func, args, kwargs)
                        span.record_input(inputs)
                        if agent_name:
                            span.record_agent_name(agent_name)

                        # Capture state before execution
                        self._conditionally_capture_and_record_state(
                            span, args, is_before=True
                        )

                        try:
                            if self.deep_tracing:
                                with _DeepTracer(self):
                                    result = await func(*args, **kwargs)
                            else:
                                result = await func(*args, **kwargs)
                        except Exception as e:
                            _capture_exception_for_trace(current_trace, sys.exc_info())
                            raise e

                        # Capture state after execution
                        self._conditionally_capture_and_record_state(
                            span, args, is_before=False
                        )

                        span.record_output(result)
                        return result

            return async_wrapper
        else:
            # Non-async function implementation with deep tracing
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                nonlocal original_span_name
                class_name = None
                span_name = original_span_name
                agent_name = None
                if args and hasattr(args[0], "__class__"):
                    class_name = args[0].__class__.__name__
                    agent_name = get_instance_prefixed_name(
                        args[0], class_name, self.class_identifiers
                    )
                # Get current trace from context
                current_trace = self.get_current_trace()

                # If there's no current trace, create a root trace
                if not current_trace:
                    trace_id = str(uuid.uuid4())
                    project = self.project_name

                    # Create a new trace client to serve as the root
                    current_trace = TraceClient(
                        self,
                        trace_id,
                        span_name,  # MODIFIED: Use span_name directly
                        project_name=project,
                        enable_monitoring=self.enable_monitoring,
                        enable_evaluations=self.enable_evaluations,
                    )

                    trace_token = self.set_current_trace(current_trace)

                    try:
                        # Use span for the function execution within the root trace
                        # This sets the current_span_var
                        with current_trace.span(
                            span_name, span_type=span_type
                        ) as span:  # MODIFIED: Use span_name directly
                            # Record inputs
                            inputs = combine_args_kwargs(func, args, kwargs)
                            span.record_input(inputs)
                            if agent_name:
                                span.record_agent_name(agent_name)
                            # Capture state before execution
                            self._conditionally_capture_and_record_state(
                                span, args, is_before=True
                            )

                            try:
                                if self.deep_tracing:
                                    with _DeepTracer(self):
                                        result = func(*args, **kwargs)
                                else:
                                    result = func(*args, **kwargs)
                            except Exception as e:
                                _capture_exception_for_trace(
                                    current_trace, sys.exc_info()
                                )
                                raise e

                            # Capture state after execution
                            self._conditionally_capture_and_record_state(
                                span, args, is_before=False
                            )

                            # Record output
                            span.record_output(result)
                            return result
                    finally:
                        # Flush background spans before saving the trace
                        try:
                            # Save the completed trace
                            trace_id, server_response = current_trace.save(
                                final_save=True
                            )

                            # Store the complete trace data instead of just server response
                            complete_trace_data = {
                                "trace_id": current_trace.trace_id,
                                "name": current_trace.name,
                                "created_at": datetime.fromtimestamp(
                                    current_trace.start_time or time.time(),
                                    timezone.utc,
                                ).isoformat(),
                                "duration": current_trace.get_duration(),
                                "trace_spans": [
                                    span.model_dump()
                                    for span in current_trace.trace_spans
                                ],
                                "offline_mode": self.offline_mode,
                                "parent_trace_id": current_trace.parent_trace_id,
                                "parent_name": current_trace.parent_name,
                            }
                            self.traces.append(complete_trace_data)
                            # Reset trace context (span context resets automatically)
                            self.reset_current_trace(trace_token)
                        except Exception as e:
                            judgeval_logger.warning(f"Issue with save: {e}")
                            return
                else:
                    with current_trace.span(span_name, span_type=span_type) as span:
                        inputs = combine_args_kwargs(func, args, kwargs)
                        span.record_input(inputs)
                        if agent_name:
                            span.record_agent_name(agent_name)

                        # Capture state before execution
                        self._conditionally_capture_and_record_state(
                            span, args, is_before=True
                        )

                        try:
                            if self.deep_tracing:
                                with _DeepTracer(self):
                                    result = func(*args, **kwargs)
                            else:
                                result = func(*args, **kwargs)
                        except Exception as e:
                            _capture_exception_for_trace(current_trace, sys.exc_info())
                            raise e

                        # Capture state after execution
                        self._conditionally_capture_and_record_state(
                            span, args, is_before=False
                        )

                        span.record_output(result)
                        return result

            return wrapper

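    # Usage sketch for observe() (function and span names are illustrative):
    #
    #     @tracer.observe(name="fetch_answer", span_type="tool")
    #     def fetch_answer(question: str) -> str:
    #         return f"answer to {question}"
    #
    #     fetch_answer("What is tracing?")  # creates a root trace if none is active
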
    def observe_tools(
        self,
        cls=None,
        *,
        exclude_methods: Optional[List[str]] = None,
        include_private: bool = False,
        warn_on_double_decoration: bool = True,
    ):
        """
        Automatically adds @observe(span_type="tool") to all methods in a class.

        Args:
            cls: The class to decorate (automatically provided when used as decorator)
            exclude_methods: List of method names to skip decorating. Defaults to common magic methods
            include_private: Whether to decorate methods starting with underscore. Defaults to False
            warn_on_double_decoration: Whether to print warnings when skipping already-decorated methods. Defaults to True
        """

        if exclude_methods is None:
            exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]

        def decorate_class(cls):
            if not self.enable_monitoring:
                return cls

            decorated = []
            skipped = []

            for name in dir(cls):
                method = getattr(cls, name)

                if (
                    not callable(method)
                    or name in exclude_methods
                    or (name.startswith("_") and not include_private)
                    or not hasattr(cls, name)
                ):
                    continue

                if hasattr(method, "_judgment_span_name"):
                    skipped.append(name)
                    if warn_on_double_decoration:
                        judgeval_logger.info(
                            f"{cls.__name__}.{name} already decorated, skipping"
                        )
                    continue

                try:
                    decorated_method = self.observe(method, span_type="tool")
                    setattr(cls, name, decorated_method)
                    decorated.append(name)
                except Exception as e:
                    if warn_on_double_decoration:
                        judgeval_logger.warning(
                            f"Failed to decorate {cls.__name__}.{name}: {e}"
                        )

            return cls

        return decorate_class if cls is None else decorate_class(cls)

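    # Usage sketch for observe_tools() (the class below is illustrative):
    #
    #     @tracer.observe_tools()
    #     class SearchTools:
    #         def web_search(self, query: str) -> str:  # wrapped as span_type="tool"
    #             return f"results for {query}"
    #
    #         def _helper(self) -> None:  # skipped: private, include_private=False
    #             ...
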
    def async_evaluate(self, *args, **kwargs):
        try:
            if not self.enable_monitoring or not self.enable_evaluations:
                return

            # --- Get trace_id passed explicitly (if any) ---
            passed_trace_id = kwargs.pop(
                "trace_id", None
            )  # Get and remove trace_id from kwargs

            current_trace = self.get_current_trace()

            if current_trace:
                # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
                # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
                if passed_trace_id:
                    kwargs["trace_id"] = (
                        passed_trace_id  # Re-add if needed by TraceClient.async_evaluate
                    )
                current_trace.async_evaluate(*args, **kwargs)
            else:
                judgeval_logger.warning(
                    "No trace found (context var or fallback), skipping evaluation"
                )  # Modified warning
        except Exception as e:
            judgeval_logger.warning(f"Issue with async_evaluate: {e}")

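    # Usage sketch for async_evaluate(); all keyword arguments below are
    # assumptions for illustration only and are simply forwarded to
    # TraceClient.async_evaluate:
    #
    #     tracer.async_evaluate(
    #         scorers=[...],                        # judgeval scorer instances
    #         input="What is the capital of France?",
    #         actual_output="Paris",
    #         model="gpt-4o",                       # hypothetical judge model name
    #     )
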
    def update_metadata(self, metadata: dict):
        """
        Update metadata for the current trace.

        Args:
            metadata: Metadata as a dictionary
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.update_metadata(metadata)
        else:
            judgeval_logger.warning("No current trace found, cannot set metadata")

    def set_customer_id(self, customer_id: str):
        """
        Set the customer ID for the current trace.

        Args:
            customer_id: The customer ID to set
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.set_customer_id(customer_id)
        else:
            judgeval_logger.warning("No current trace found, cannot set customer ID")

    def set_tags(self, tags: List[Union[str, set, tuple]]):
        """
        Set the tags for the current trace.

        Args:
            tags: List of tags to set
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.set_tags(tags)
        else:
            judgeval_logger.warning("No current trace found, cannot set tags")

    def get_background_span_service(self) -> Optional[BackgroundSpanService]:
        """Get the background span service instance."""
        return self.background_span_service

    def flush_background_spans(self):
        """Flush all pending spans in the background service."""
        if self.background_span_service:
            self.background_span_service.flush()

    def shutdown_background_service(self):
        """Shutdown the background span service."""
        if self.background_span_service:
            self.background_span_service.shutdown()
            self.background_span_service = None


def _get_current_trace(
    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
):
    if trace_across_async_contexts:
        return Tracer.current_trace
    else:
        return current_trace_var.get()


def wrap(
    client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
) -> Any:
    """
    Wraps an API client to add tracing capabilities.
    Supports OpenAI, Together, Anthropic, and Google GenAI clients.
    Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
    """
    (
        span_name,
        original_create,
        original_responses_create,
        original_stream,
        original_beta_parse,
    ) = _get_client_config(client)

    def _record_input_and_check_streaming(span, kwargs, is_responses=False):
        """Record input and check for streaming"""
        is_streaming = kwargs.get("stream", False)

        # Record input based on whether this is a responses endpoint
        if is_responses:
            span.record_input(kwargs)
        else:
            input_data = _format_input_data(client, **kwargs)
            span.record_input(input_data)

        # Warn about token counting limitations with streaming
        if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
            if not kwargs.get("stream_options", {}).get("include_usage"):
                judgeval_logger.warning(
                    "OpenAI streaming calls don't include token counts by default. "
                    "To enable token counting with streams, set stream_options={'include_usage': True} "
                    "in your API call arguments."
                )

        return is_streaming

    def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
        """Format and record the output in the span"""
        if is_streaming:
            output_entry = span.record_output("<pending stream>")
            wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
            return wrapper_func(
                response, client, output_entry, trace_across_async_contexts
            )
        else:
            format_func = (
                _format_response_output_data if is_responses else _format_output_data
            )
            output, usage = format_func(client, response)
            span.record_output(output)
            span.record_usage(usage)

            # Queue the completed LLM span now that it has all data (input, output, usage)
            current_trace = _get_current_trace(trace_across_async_contexts)
            if current_trace and current_trace.background_span_service:
                # Get the current span from the trace client
                current_span_id = current_trace.get_current_span()
                if current_span_id and current_span_id in current_trace.span_id_to_span:
                    completed_span = current_trace.span_id_to_span[current_span_id]
                    current_trace.background_span_service.queue_span(
                        completed_span, span_state="completed"
                    )

            return response

    # --- Traced Async Functions ---
    async def traced_create_async(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return await original_create(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)

            try:
                response_or_iterator = await original_create(*args, **kwargs)
                return _format_and_record_output(
                    span, response_or_iterator, is_streaming, True, False
                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e

    async def traced_beta_parse_async(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return await original_beta_parse(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)

            try:
                response_or_iterator = await original_beta_parse(*args, **kwargs)
                return _format_and_record_output(
                    span, response_or_iterator, is_streaming, True, False
                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e

    # Async responses for OpenAI clients
    async def traced_response_create_async(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return await original_responses_create(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(
                span, kwargs, is_responses=True
            )

            try:
                response_or_iterator = await original_responses_create(*args, **kwargs)
                return _format_and_record_output(
                    span, response_or_iterator, is_streaming, True, True
                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e

    # Function replacing .stream() for async clients
    def traced_stream_async(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace or not original_stream:
            return original_stream(*args, **kwargs)

        original_manager = original_stream(*args, **kwargs)
        return _TracedAsyncStreamManagerWrapper(
            original_manager=original_manager,
            client=client,
            span_name=span_name,
            trace_client=current_trace,
            stream_wrapper_func=_async_stream_wrapper,
            input_kwargs=kwargs,
            trace_across_async_contexts=trace_across_async_contexts,
        )

    # --- Traced Sync Functions ---
    def traced_create_sync(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return original_create(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)

            try:
                response_or_iterator = original_create(*args, **kwargs)
                return _format_and_record_output(
                    span, response_or_iterator, is_streaming, False, False
                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e

    def traced_beta_parse_sync(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return original_beta_parse(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)

            try:
                response_or_iterator = original_beta_parse(*args, **kwargs)
                return _format_and_record_output(
                    span, response_or_iterator, is_streaming, False, False
                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e

    def traced_response_create_sync(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return original_responses_create(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(
                span, kwargs, is_responses=True
            )

            try:
                response_or_iterator = original_responses_create(*args, **kwargs)
                return _format_and_record_output(
                    span, response_or_iterator, is_streaming, False, True
                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e

    # Function replacing sync .stream()
    def traced_stream_sync(*args, **kwargs):
        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace or not original_stream:
            return original_stream(*args, **kwargs)

        original_manager = original_stream(*args, **kwargs)
        return _TracedSyncStreamManagerWrapper(
            original_manager=original_manager,
            client=client,
            span_name=span_name,
            trace_client=current_trace,
            stream_wrapper_func=_sync_stream_wrapper,
            input_kwargs=kwargs,
            trace_across_async_contexts=trace_across_async_contexts,
        )

    # --- Assign Traced Methods to Client Instance ---
    if isinstance(client, (AsyncOpenAI, AsyncTogether)):
        client.chat.completions.create = traced_create_async
        if hasattr(client, "responses") and hasattr(client.responses, "create"):
            client.responses.create = traced_response_create_async
        if (
            hasattr(client, "beta")
            and hasattr(client.beta, "chat")
            and hasattr(client.beta.chat, "completions")
            and hasattr(client.beta.chat.completions, "parse")
        ):
            client.beta.chat.completions.parse = traced_beta_parse_async
    elif isinstance(client, AsyncAnthropic):
        client.messages.create = traced_create_async
        if original_stream:
            client.messages.stream = traced_stream_async
    elif isinstance(client, genai.client.AsyncClient):
        client.models.generate_content = traced_create_async
    elif isinstance(client, (OpenAI, Together)):
        client.chat.completions.create = traced_create_sync
        if hasattr(client, "responses") and hasattr(client.responses, "create"):
            client.responses.create = traced_response_create_sync
        if (
            hasattr(client, "beta")
            and hasattr(client.beta, "chat")
            and hasattr(client.beta.chat, "completions")
            and hasattr(client.beta.chat.completions, "parse")
        ):
            client.beta.chat.completions.parse = traced_beta_parse_sync
    elif isinstance(client, Anthropic):
        client.messages.create = traced_create_sync
        if original_stream:
            client.messages.stream = traced_stream_sync
    elif isinstance(client, genai.Client):
        client.models.generate_content = traced_create_sync

    return client

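
# Usage sketch for wrap() (illustrative; assumes an OpenAI client and a
# hypothetical model name, and is not executed at import time):
def _example_wrap_usage(tracer: Tracer) -> None:
    client = wrap(OpenAI())  # patches chat.completions.create in place
    with tracer.trace("llm_call"):
        client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "hi"}],
        )
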
# Helper functions for client-specific operations


def _get_client_config(
    client: ApiClient,
) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
    """Returns configuration tuple for the given API client.

    Args:
        client: An instance of OpenAI, Together, or Anthropic client

    Returns:
        tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
            - span_name: String identifier for tracing
            - create_method: Reference to the client's creation method
            - responses_method: Reference to the client's responses method (if applicable)
            - stream_method: Reference to the client's stream method (if applicable)
            - beta_parse_method: Reference to the client's beta parse method (if applicable)

    Raises:
        ValueError: If client type is not supported
    """
    if isinstance(client, (OpenAI, AsyncOpenAI)):
        return (
            "OPENAI_API_CALL",
            client.chat.completions.create,
            client.responses.create,
            None,
            client.beta.chat.completions.parse,
        )
    elif isinstance(client, (Together, AsyncTogether)):
        return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
    elif isinstance(client, (Anthropic, AsyncAnthropic)):
        return (
            "ANTHROPIC_API_CALL",
            client.messages.create,
            None,
            client.messages.stream,
            None,
        )
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
    raise ValueError(f"Unsupported client type: {type(client)}")


def _format_input_data(client: ApiClient, **kwargs) -> dict:
    """Format input parameters based on client type.

    Extracts relevant parameters from kwargs based on the client type
    to ensure consistent tracing across different APIs.
    """
    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
        input_data = {
            "model": kwargs.get("model"),
            "messages": kwargs.get("messages"),
        }
        if kwargs.get("response_format"):
            input_data["response_format"] = kwargs.get("response_format")
        return input_data
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        return {"model": kwargs.get("model"), "contents": kwargs.get("contents")}
    # Anthropic requires additional max_tokens parameter
    return {
        "model": kwargs.get("model"),
        "messages": kwargs.get("messages"),
        "max_tokens": kwargs.get("max_tokens"),
    }


def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
    """Format API response data based on client type.

    Normalizes different response formats into a consistent structure
    for tracing purposes.
    """
    message_content = None
    prompt_tokens = 0
    completion_tokens = 0
    model_name = None
    cache_read_input_tokens = 0
    cache_creation_input_tokens = 0

    if isinstance(client, (OpenAI, AsyncOpenAI)):
        model_name = response.model
        prompt_tokens = response.usage.input_tokens
        completion_tokens = response.usage.output_tokens
        cache_read_input_tokens = response.usage.input_tokens_details.cached_tokens
        message_content = "".join(seg.text for seg in response.output[0].content)
    else:
        judgeval_logger.warning(f"Unsupported client type: {type(client)}")
        return None, None

    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
    )
    total_cost_usd = (
        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
        model_name=model_name,
    )
    return message_content, usage


def _format_output_data(
    client: ApiClient, response: Any
) -> tuple[Optional[str], Optional[TraceUsage]]:
    """Format API response data based on client type.

    Normalizes different response formats into a consistent structure
    for tracing purposes.

    Returns:
        tuple: (content, usage)
            - content: The generated text
            - usage: Token usage statistics
    """
    prompt_tokens = 0
    completion_tokens = 0
    cache_read_input_tokens = 0
    cache_creation_input_tokens = 0
    model_name = None
    message_content = None

    if isinstance(client, (OpenAI, AsyncOpenAI)):
        model_name = response.model
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
        cache_read_input_tokens = response.usage.prompt_tokens_details.cached_tokens
        if (
            hasattr(response.choices[0].message, "parsed")
            and response.choices[0].message.parsed
        ):
            message_content = response.choices[0].message.parsed
        else:
            message_content = response.choices[0].message.content

    elif isinstance(client, (Together, AsyncTogether)):
        model_name = "together_ai/" + response.model
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
        if (
            hasattr(response.choices[0].message, "parsed")
            and response.choices[0].message.parsed
        ):
            message_content = response.choices[0].message.parsed
        else:
            message_content = response.choices[0].message.content
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        model_name = response.model_version
        prompt_tokens = response.usage_metadata.prompt_token_count
        completion_tokens = response.usage_metadata.candidates_token_count
        message_content = response.candidates[0].content.parts[0].text
    elif isinstance(client, (Anthropic, AsyncAnthropic)):
        model_name = response.model
        prompt_tokens = response.usage.input_tokens
        completion_tokens = response.usage.output_tokens
        cache_read_input_tokens = response.usage.cache_read_input_tokens
        cache_creation_input_tokens = response.usage.cache_creation_input_tokens
        message_content = response.content[0].text
    else:
        judgeval_logger.warning(f"Unsupported client type: {type(client)}")
        return None, None

    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
    )
    total_cost_usd = (
        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
        model_name=model_name,
    )
    return message_content, usage

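
# Worked example of the cost arithmetic above, with assumed per-token USD rates
# (cost_per_token normally derives these from the model's pricing table; the
# rates here are illustrative, not real pricing):
def _example_usage_cost() -> float:
    prompt_tokens, completion_tokens = 1_000, 250
    prompt_rate, completion_rate = 2.5e-06, 1.0e-05  # assumed USD per token
    prompt_cost = prompt_tokens * prompt_rate  # 0.0025
    completion_cost = completion_tokens * completion_rate  # 0.0025
    return prompt_cost + completion_cost  # total_cost_usd == 0.005
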
def combine_args_kwargs(func, args, kwargs):
    """
    Combine positional arguments and keyword arguments into a single dictionary.

    Args:
        func: The function being called
        args: Tuple of positional arguments
        kwargs: Dictionary of keyword arguments

    Returns:
        A dictionary combining both args and kwargs
    """
    try:
        import inspect

        sig = inspect.signature(func)
        param_names = list(sig.parameters.keys())

        args_dict = {}
        for i, arg in enumerate(args):
            if i < len(param_names):
                args_dict[param_names[i]] = arg
            else:
                args_dict[f"arg{i}"] = arg

        return {**args_dict, **kwargs}
    except Exception:
        # Fallback if signature inspection fails
        return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}

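
# Worked example for combine_args_kwargs (illustrative): positional args are
# mapped onto the function's parameter names, then merged with kwargs.
def _example_combine_args_kwargs() -> dict:
    def greet(name, greeting="hello", punctuation="!"):
        return f"{greeting}, {name}{punctuation}"

    combined = combine_args_kwargs(greet, ("Ada",), {"punctuation": "?"})
    assert combined == {"name": "Ada", "punctuation": "?"}
    return combined
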
# NOTE: This builds once, can be tweaked if we are missing / capturing other unnecessary modules
# @link https://docs.python.org/3.13/library/sysconfig.html
_TRACE_FILEPATH_BLOCKLIST = tuple(
    os.path.realpath(p) + os.sep
    for p in {
        sysconfig.get_paths()["stdlib"],
        sysconfig.get_paths().get("platstdlib", ""),
        *site.getsitepackages(),
        site.getusersitepackages(),
        *(
            [os.path.join(os.path.dirname(__file__), "../../judgeval/")]
            if os.environ.get("JUDGMENT_DEV")
            else []
        ),
    }
    if p
)


# Add the new TraceThreadPoolExecutor class
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    A ThreadPoolExecutor subclass that automatically propagates contextvars
    from the submitting thread to the worker thread using copy_context().run().

    This ensures that context variables like `current_trace_var` and
    `current_span_var` are available within functions executed by the pool,
    allowing the Tracer to maintain correct parent-child relationships across
    thread boundaries.
    """

    def submit(self, fn, /, *args, **kwargs):
        """
        Submit a callable to be executed with the captured context.
        """
        # Capture context from the submitting thread
        ctx = contextvars.copy_context()

        # We use functools.partial to bind the arguments to the function *now*,
        # as ctx.run doesn't directly accept *args, **kwargs in the same way
        # submit does. It expects ctx.run(callable, arg1, arg2...).
        func_with_bound_args = functools.partial(fn, *args, **kwargs)

        # Submit the ctx.run callable to the original executor.
        # ctx.run will execute the (now argument-bound) function within the
        # captured context in the worker thread.
        return super().submit(ctx.run, func_with_bound_args)

    # Note: The `map` method would also need to be overridden for full context
    # propagation if users rely on it, but `submit` is the most common use case.

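
# Usage sketch for TraceThreadPoolExecutor (illustrative): a context variable
# set before submit() is visible inside the worker thread because submit()
# runs the callable via the captured context.
def _example_trace_threadpool() -> None:
    request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")
    request_id.set("req-123")

    with TraceThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(request_id.get)
    assert future.result() == "req-123"
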
# Helper functions for stream processing
# ---------------------------------------


def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
    """Extracts the text content from a stream chunk based on the client type."""
    try:
        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
            return chunk.choices[0].delta.content
        elif isinstance(client, (Anthropic, AsyncAnthropic)):
            # Anthropic streams various event types; we only care about content blocks
            if chunk.type == "content_block_delta":
                return chunk.delta.text
        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
            # Google streams Candidate objects
            if (
                chunk.candidates
                and chunk.candidates[0].content
                and chunk.candidates[0].content.parts
            ):
                return chunk.candidates[0].content.parts[0].text
    except (AttributeError, IndexError, KeyError):
        # Handle cases where the chunk structure is unexpected or has no content
        pass  # Return None
    return None

def _extract_usage_from_final_chunk(
    client: ApiClient, chunk: Any
) -> Optional[Union[TraceUsage, Dict[str, int]]]:
    """Extracts usage data if present in the *final* chunk (client-specific)."""
    try:
        # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available.
        # This typically requires specific API versions or settings; usage is often *not* streamed.
        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
            prompt_tokens = None
            completion_tokens = None
            # Check if usage is directly on the chunk (some models might do this)
            if hasattr(chunk, "usage") and chunk.usage:
                prompt_tokens = chunk.usage.prompt_tokens
                completion_tokens = chunk.usage.completion_tokens
            # Check if usage is nested within choices (less common for a final chunk, but check)
            elif (
                chunk.choices
                and hasattr(chunk.choices[0], "usage")
                and chunk.choices[0].usage
            ):
                prompt_tokens = chunk.choices[0].usage.prompt_tokens
                completion_tokens = chunk.choices[0].usage.completion_tokens

            if prompt_tokens is None or completion_tokens is None:
                return None  # This chunk carries no usage data

            raw_model_name = getattr(chunk, "model", "")
            model_name = (
                "together_ai/" + raw_model_name
                if isinstance(client, (Together, AsyncTogether))
                else raw_model_name
            )

            prompt_cost, completion_cost = cost_per_token(
                model=model_name,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            total_cost_usd = (
                (prompt_cost + completion_cost)
                if prompt_cost and completion_cost
                else None
            )
            return TraceUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                prompt_tokens_cost_usd=prompt_cost,
                completion_tokens_cost_usd=completion_cost,
                total_cost_usd=total_cost_usd,
                model_name=raw_model_name,
            )
        # Anthropic includes usage in the 'message_stop' event type
        elif isinstance(client, (Anthropic, AsyncAnthropic)):
            if chunk.type == "message_stop":
                # Anthropic final usage is attached to the *message* object, not the
                # chunk itself, and is typically not available in the stream.
                # Placeholder: Anthropic usage usually needs a separate call or context.
                pass
        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
            # Google provides usage metadata on the full response object, not typically
            # streamed per chunk. It may be in the *last* chunk's usage_metadata if the
            # stream implementation supports it.
            if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
                return {
                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
                    "total_tokens": chunk.usage_metadata.total_token_count,
                }

    except (AttributeError, IndexError, KeyError, TypeError):
        # Handle cases where usage data is missing or malformed
        pass  # Return None
    return None

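# --- Editor's illustrative sketch (not part of the original module) ---
# For the OpenAI path above, a final usage chunk is only emitted when the
# caller opts in via the API's `stream_options` parameter; `openai_client`
# and the model name are hypothetical.
def _example_stream_with_usage(openai_client):
    return openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},  # final chunk carries .usage
    )
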
# --- Sync Stream Wrapper ---
def _sync_stream_wrapper(
    original_stream: Iterator,
    client: ApiClient,
    span: TraceSpan,
    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
) -> Generator[Any, None, None]:
    """Wraps a synchronous stream iterator to capture content and update the trace."""
    content_parts = []  # Use a list instead of string concatenation
    final_usage = None
    last_chunk = None
    try:
        for chunk in original_stream:
            content_part = _extract_content_from_chunk(client, chunk)
            if content_part:
                content_parts.append(content_part)  # Append instead of concatenating
            last_chunk = chunk  # Keep track of the last chunk for potential usage data
            yield chunk  # Pass the chunk to the caller
    finally:
        # Attempt to extract usage from the last chunk received
        if last_chunk:
            final_usage = _extract_usage_from_final_chunk(client, last_chunk)

        # Update the trace entry with the accumulated content and usage
        span.output = "".join(content_parts)
        span.usage = final_usage

        # Queue the completed LLM span now that streaming is done and all data is available
        current_trace = _get_current_trace(trace_across_async_contexts)
        if current_trace and current_trace.background_span_service:
            current_trace.background_span_service.queue_span(
                span, span_state="completed"
            )

        # Note: We might need to adjust _serialize_output if this dict causes issues,
        # but Pydantic's model_dump should handle dicts.

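# --- Editor's illustrative sketch (not part of the original module) ---
# How the sync wrapper above is meant to be consumed: the caller iterates the
# wrapped generator exactly as it would the raw stream, and the span is
# populated in the `finally` block once iteration ends. `raw_stream`, `client`,
# and `span` are hypothetical, pre-existing objects.
def _example_consume_wrapped_stream(raw_stream, client, span):
    chunks = list(_sync_stream_wrapper(raw_stream, client, span))
    # span.output and span.usage are now filled in
    return chunks, span.output
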
# --- Async Stream Wrapper ---
async def _async_stream_wrapper(
    original_stream: AsyncIterator,
    client: ApiClient,
    span: TraceSpan,
    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
) -> AsyncGenerator[Any, None]:
    """Wraps an asynchronous stream iterator to capture content and update the trace."""
    content_parts = []  # Use a list instead of string concatenation
    final_usage_data = None
    last_content_chunk = None
    anthropic_input_tokens = 0
    anthropic_output_tokens = 0

    try:
        model_name = ""
        async for chunk in original_stream:
            # Check for OpenAI's final usage chunk
            if (
                isinstance(client, (AsyncOpenAI, OpenAI))
                and hasattr(chunk, "usage")
                and chunk.usage is not None
            ):
                final_usage_data = {
                    "prompt_tokens": chunk.usage.prompt_tokens,
                    "completion_tokens": chunk.usage.completion_tokens,
                    "total_tokens": chunk.usage.total_tokens,
                }
                model_name = chunk.model
                yield chunk
                continue

            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
                chunk, "type"
            ):
                if chunk.type == "message_start":
                    if (
                        hasattr(chunk, "message")
                        and hasattr(chunk.message, "usage")
                        and hasattr(chunk.message.usage, "input_tokens")
                    ):
                        anthropic_input_tokens = chunk.message.usage.input_tokens
                        model_name = chunk.message.model
                elif chunk.type == "message_delta":
                    if hasattr(chunk, "usage") and hasattr(
                        chunk.usage, "output_tokens"
                    ):
                        anthropic_output_tokens = chunk.usage.output_tokens

            content_part = _extract_content_from_chunk(client, chunk)
            if content_part:
                content_parts.append(content_part)  # Append instead of concatenating
                last_content_chunk = chunk

            yield chunk
    finally:
        anthropic_final_usage = None
        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
            anthropic_input_tokens > 0 or anthropic_output_tokens > 0
        ):
            anthropic_final_usage = {
                "prompt_tokens": anthropic_input_tokens,
                "completion_tokens": anthropic_output_tokens,
                "total_tokens": anthropic_input_tokens + anthropic_output_tokens,
            }

        usage_info = None
        if final_usage_data:
            usage_info = final_usage_data
        elif anthropic_final_usage:
            usage_info = anthropic_final_usage
        elif last_content_chunk:
            usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)

        if usage_info and not isinstance(usage_info, TraceUsage):
            prompt_cost, completion_cost = cost_per_token(
                model=model_name,
                prompt_tokens=usage_info["prompt_tokens"],
                completion_tokens=usage_info["completion_tokens"],
            )
            usage_info = TraceUsage(
                prompt_tokens=usage_info["prompt_tokens"],
                completion_tokens=usage_info["completion_tokens"],
                total_tokens=usage_info["total_tokens"],
                prompt_tokens_cost_usd=prompt_cost,
                completion_tokens_cost_usd=completion_cost,
                # Guard against cost_per_token returning (None, None)
                total_cost_usd=(
                    (prompt_cost + completion_cost)
                    if prompt_cost is not None and completion_cost is not None
                    else None
                ),
                model_name=model_name,
            )
        if span and hasattr(span, "output"):
            span.output = "".join(content_parts)
            span.usage = usage_info
            start_ts = getattr(span, "created_at", time.time())
            span.duration = time.time() - start_ts

        # Queue the completed LLM span now that async streaming is done and all data is available
        current_trace = _get_current_trace(trace_across_async_contexts)
        if current_trace and current_trace.background_span_service:
            current_trace.background_span_service.queue_span(
                span, span_state="completed"
            )

def cost_per_token(*args, **kwargs):
    try:
        prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = (
            _original_cost_per_token(*args, **kwargs)
        )
        if (
            prompt_tokens_cost_usd_dollar == 0
            and completion_tokens_cost_usd_dollar == 0
        ):
            judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
    except Exception as e:
        judgeval_logger.warning(f"Error calculating cost per token: {e}")
        return None, None

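# --- Editor's illustrative sketch (not part of the original module) ---
# Querying LiteLLM pricing through the wrapper above; the token counts are
# hypothetical, and (None, None) comes back if the lookup fails.
def _example_cost_lookup():
    prompt_cost, completion_cost = cost_per_token(
        model="gpt-4o-mini", prompt_tokens=100, completion_tokens=20
    )
    return prompt_cost, completion_cost
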
class _BaseStreamManagerWrapper:
    def __init__(
        self,
        original_manager,
        client,
        span_name,
        trace_client,
        stream_wrapper_func,
        input_kwargs,
        trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
    ):
        self._original_manager = original_manager
        self._client = client
        self._span_name = span_name
        self._trace_client = trace_client
        self._stream_wrapper_func = stream_wrapper_func
        self._input_kwargs = input_kwargs
        self._parent_span_id_at_entry = None
        self._trace_across_async_contexts = trace_across_async_contexts

    def _create_span(self):
        start_time = time.time()
        span_id = str(uuid.uuid4())
        current_depth = 0
        if (
            self._parent_span_id_at_entry
            and self._parent_span_id_at_entry in self._trace_client._span_depths
        ):
            current_depth = (
                self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
            )
        self._trace_client._span_depths[span_id] = current_depth
        span = TraceSpan(
            function=self._span_name,
            span_id=span_id,
            trace_id=self._trace_client.trace_id,
            depth=current_depth,
            message=self._span_name,
            created_at=start_time,
            span_type="llm",
            parent_span_id=self._parent_span_id_at_entry,
        )
        self._trace_client.add_span(span)
        return span_id, span

    def _finalize_span(self, span_id):
        span = self._trace_client.span_id_to_span.get(span_id)
        if span:
            span.duration = time.time() - span.created_at
        if span_id in self._trace_client._span_depths:
            del self._trace_client._span_depths[span_id]

class _TracedAsyncStreamManagerWrapper(
    _BaseStreamManagerWrapper, AbstractAsyncContextManager
):
    async def __aenter__(self):
        # Check the trace client first: reading the current span before this
        # guard would raise AttributeError when no trace client is set.
        if not self._trace_client:
            return await self._original_manager.__aenter__()
        self._parent_span_id_at_entry = self._trace_client.get_current_span()

        span_id, span = self._create_span()
        self._span_context_token = self._trace_client.set_current_span(span_id)
        span.inputs = _format_input_data(self._client, **self._input_kwargs)

        # Call the original __aenter__ and expect it to be an async generator
        raw_iterator = await self._original_manager.__aenter__()
        span.output = "<pending stream>"
        return self._stream_wrapper_func(
            raw_iterator, self._client, span, self._trace_across_async_contexts
        )

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if hasattr(self, "_span_context_token"):
            span_id = self._trace_client.get_current_span()
            self._finalize_span(span_id)
            self._trace_client.reset_current_span(self._span_context_token)
            delattr(self, "_span_context_token")
        return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)

class _TracedSyncStreamManagerWrapper(
    _BaseStreamManagerWrapper, AbstractContextManager
):
    def __enter__(self):
        # Same guard ordering as the async wrapper above
        if not self._trace_client:
            return self._original_manager.__enter__()
        self._parent_span_id_at_entry = self._trace_client.get_current_span()

        span_id, span = self._create_span()
        self._span_context_token = self._trace_client.set_current_span(span_id)
        span.inputs = _format_input_data(self._client, **self._input_kwargs)

        raw_iterator = self._original_manager.__enter__()
        span.output = "<pending stream>"
        return self._stream_wrapper_func(
            raw_iterator, self._client, span, self._trace_across_async_contexts
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        if hasattr(self, "_span_context_token"):
            span_id = self._trace_client.get_current_span()
            self._finalize_span(span_id)
            self._trace_client.reset_current_span(self._span_context_token)
            delattr(self, "_span_context_token")
        return self._original_manager.__exit__(exc_type, exc_val, exc_tb)

# --- Helper function for instance-prefixed qual_name ---
def get_instance_prefixed_name(instance, class_name, class_identifiers):
    """
    Returns the agent name (prefix) if the class and its identifier attribute
    are found in class_identifiers; returns None if the class is not registered.
    Raises if the class is registered but the instance lacks the attribute.
    """
    if class_name in class_identifiers:
        class_config = class_identifiers[class_name]
        attr = class_config["identifier"]

        if hasattr(instance, attr):
            instance_name = getattr(instance, attr)
            return instance_name
        else:
            raise Exception(
                f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
            )
    return None
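
# --- Editor's illustrative sketch (not part of the original module) ---
# The shapes below are hypothetical but match the lookup above: the registry
# maps a class name to the attribute that names an agent instance.
def _example_instance_name():
    class Agent:  # hypothetical class registered via identify()
        def __init__(self, name):
            self.name = name

    registry = {"Agent": {"identifier": "name"}}
    return get_instance_prefixed_name(Agent("planner"), "Agent", registry)  # -> "planner"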