judgeval 0.0.55__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/api/__init__.py +3 -0
- judgeval/common/api/api.py +352 -0
- judgeval/common/api/constants.py +165 -0
- judgeval/common/storage/__init__.py +6 -0
- judgeval/common/tracer/__init__.py +31 -0
- judgeval/common/tracer/constants.py +22 -0
- judgeval/common/tracer/core.py +1916 -0
- judgeval/common/tracer/otel_exporter.py +108 -0
- judgeval/common/tracer/otel_span_processor.py +234 -0
- judgeval/common/tracer/span_processor.py +37 -0
- judgeval/common/tracer/span_transformer.py +211 -0
- judgeval/common/tracer/trace_manager.py +92 -0
- judgeval/common/utils.py +2 -2
- judgeval/constants.py +3 -30
- judgeval/data/datasets/eval_dataset_client.py +29 -156
- judgeval/data/judgment_types.py +4 -12
- judgeval/data/result.py +1 -1
- judgeval/data/scorer_data.py +2 -2
- judgeval/data/scripts/openapi_transform.py +1 -1
- judgeval/data/trace.py +66 -1
- judgeval/data/trace_run.py +0 -3
- judgeval/evaluation_run.py +0 -2
- judgeval/integrations/langgraph.py +43 -164
- judgeval/judgment_client.py +17 -211
- judgeval/run_evaluation.py +209 -611
- judgeval/scorers/__init__.py +2 -6
- judgeval/scorers/base_scorer.py +4 -23
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +3 -3
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +215 -0
- judgeval/scorers/score.py +2 -1
- judgeval/scorers/utils.py +1 -13
- judgeval/utils/requests.py +21 -0
- judgeval-0.1.0.dist-info/METADATA +202 -0
- {judgeval-0.0.55.dist-info → judgeval-0.1.0.dist-info}/RECORD +37 -29
- judgeval/common/tracer.py +0 -3215
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +0 -73
- judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -53
- judgeval-0.0.55.dist-info/METADATA +0 -1384
- /judgeval/common/{s3_storage.py → storage/s3_storage.py} +0 -0
- {judgeval-0.0.55.dist-info → judgeval-0.1.0.dist-info}/WHEEL +0 -0
- {judgeval-0.0.55.dist-info → judgeval-0.1.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,1916 @@
|
|
1
|
+
"""
|
2
|
+
Tracing system for judgeval that allows for function tracing using decorators.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import atexit
|
9
|
+
import functools
|
10
|
+
import inspect
|
11
|
+
import os
|
12
|
+
import threading
|
13
|
+
import time
|
14
|
+
import traceback
|
15
|
+
import uuid
|
16
|
+
import contextvars
|
17
|
+
import sys
|
18
|
+
from contextlib import (
|
19
|
+
contextmanager,
|
20
|
+
)
|
21
|
+
from datetime import datetime, timezone
|
22
|
+
from typing import (
|
23
|
+
Any,
|
24
|
+
Callable,
|
25
|
+
Dict,
|
26
|
+
Generator,
|
27
|
+
List,
|
28
|
+
Optional,
|
29
|
+
Tuple,
|
30
|
+
Union,
|
31
|
+
TypeAlias,
|
32
|
+
)
|
33
|
+
import types
|
34
|
+
|
35
|
+
from judgeval.common.tracer.constants import _TRACE_FILEPATH_BLOCKLIST
|
36
|
+
|
37
|
+
from judgeval.common.tracer.otel_span_processor import JudgmentSpanProcessor
|
38
|
+
from judgeval.common.tracer.span_processor import SpanProcessorBase
|
39
|
+
from judgeval.common.tracer.trace_manager import TraceManagerClient
|
40
|
+
from litellm import cost_per_token as _original_cost_per_token
|
41
|
+
from openai import OpenAI, AsyncOpenAI
|
42
|
+
from openai.types.chat.chat_completion import ChatCompletion
|
43
|
+
from openai.types.responses.response import Response
|
44
|
+
from openai.types.chat import ParsedChatCompletion
|
45
|
+
from together import Together, AsyncTogether
|
46
|
+
from anthropic import Anthropic, AsyncAnthropic
|
47
|
+
from google import genai
|
48
|
+
|
49
|
+
from judgeval.data import Example, Trace, TraceSpan, TraceUsage
|
50
|
+
from judgeval.scorers import APIScorerConfig, BaseScorer
|
51
|
+
from judgeval.evaluation_run import EvaluationRun
|
52
|
+
from judgeval.common.utils import ExcInfo, validate_api_key
|
53
|
+
from judgeval.common.logger import judgeval_logger
|
54
|
+
|
55
|
+
|
56
|
+
# Context-local handle on the TraceClient that is active in the current
# execution context (None when no trace is active).
current_trace_var: contextvars.ContextVar[Optional[TraceClient]] = contextvars.ContextVar(
    "current_trace", default=None
)

# Context-local id of the currently open span (None when no span is open).
current_span_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_span", default=None
)
|
60
|
+
|
61
|
+
# Plain-string label naming a span's category.
SpanType: TypeAlias = str

# Union of the LLM SDK client classes covered by this alias: the synchronous
# clients first, then their asynchronous counterparts, then Google GenAI.
ApiClient: TypeAlias = Union[
    # synchronous SDK clients
    OpenAI,
    Together,
    Anthropic,
    # asynchronous SDK clients
    AsyncOpenAI,
    AsyncAnthropic,
    AsyncTogether,
    # Google GenAI (sync and async)
    genai.Client,
    genai.client.AsyncClient,
]
|
72
|
+
|
73
|
+
|
74
|
+
class TraceClient:
    """Client for managing a single trace context.

    Holds the spans, evaluation runs and trace-level metadata (customer id,
    tags, free-form metadata) of one trace. Span updates are forwarded to the
    tracer's span processor, and the trace is persisted through a
    ``TraceManagerClient``.
    """

    def __init__(
        self,
        tracer: Tracer,
        trace_id: Optional[str] = None,
        name: str = "default",
        project_name: str | None = None,
        enable_monitoring: bool = True,
        enable_evaluations: bool = True,
        parent_trace_id: Optional[str] = None,
        parent_name: Optional[str] = None,
    ):
        self.name = name
        self.trace_id = trace_id or str(uuid.uuid4())
        self.project_name = project_name or "default_project"
        self.tracer = tracer
        self.enable_monitoring = enable_monitoring
        self.enable_evaluations = enable_evaluations
        self.parent_trace_id = parent_trace_id
        self.parent_name = parent_name
        self.customer_id: Optional[str] = None
        self.tags: List[Union[str, set, tuple]] = []
        self.metadata: Dict[str, Any] = {}
        self.has_notification: Optional[bool] = False
        # Sent with every upsert and incremented after each save().
        self.update_id: int = 1
        self.trace_spans: List[TraceSpan] = []
        self.span_id_to_span: Dict[str, TraceSpan] = {}
        self.evaluation_runs: List[EvaluationRun] = []
        # Fixed by the first save(); None means the trace has not been saved yet.
        self.start_time: Optional[float] = None
        self.trace_manager_client = TraceManagerClient(
            tracer.api_key, tracer.organization_id, tracer
        )
        # Depth of each currently open span, keyed by span id.
        self._span_depths: Dict[str, int] = {}
        # Strong references to tasks started by record_output() for coroutine
        # outputs: the event loop only keeps weak references to tasks.
        self._pending_output_tasks: set[asyncio.Task] = set()

        self.otel_span_processor = tracer.otel_span_processor

        judgeval_logger.info(
            f"🎯 TraceClient using span processor for trace {self.trace_id}"
        )

    def get_current_span(self):
        """Get the current span from the context var"""
        return self.tracer.get_current_span()

    def set_current_span(self, span: Any):
        """Set the current span from the context var"""
        return self.tracer.set_current_span(span)

    def reset_current_span(self, token: Any):
        """Reset the current span from the context var"""
        self.tracer.reset_current_span(token)

    def _get_current_span_obj(self) -> Optional[TraceSpan]:
        """Return this trace's TraceSpan for the active span id, or None.

        The active span id is read from the tracer (see ``get_current_span``).
        Nothing guarantees it belongs to this trace, so an unknown id yields
        None instead of raising KeyError.
        """
        current_span_id = self.get_current_span()
        if not current_span_id:
            return None
        return self.span_id_to_span.get(current_span_id)

    @contextmanager
    def span(self, name: str, span_type: SpanType = "span"):
        """Open a child span of the current span for the duration of a ``with`` block.

        While the block runs, the new span is the current span. Afterwards the
        previous current span is restored. Restoration also happens if creating
        or queuing the span fails, so the context var never points at a
        half-created span.

        Args:
            name: Used as both the span's ``message`` and ``function``.
            span_type: Span category label (defaults to "span").

        Yields:
            This TraceClient.
        """
        if not self.trace_spans:
            # First span of the trace: save eagerly so the trace is visible for
            # live tracking before anything finishes. Best effort only.
            try:
                self.save(final_save=False)
            except Exception as e:
                judgeval_logger.warning(
                    f"Failed to save initial trace for live tracking: {e}"
                )
        start_time = time.time()

        span_id = str(uuid.uuid4())

        parent_span_id = self.get_current_span()
        token = self.set_current_span(span_id)

        current_depth = 0
        if parent_span_id and parent_span_id in self._span_depths:
            current_depth = self._span_depths[parent_span_id] + 1
        self._span_depths[span_id] = current_depth

        span: Optional[TraceSpan] = None
        try:
            span = TraceSpan(
                span_id=span_id,
                trace_id=self.trace_id,
                depth=current_depth,
                message=name,
                created_at=start_time,
                span_type=span_type,
                parent_span_id=parent_span_id,
                function=name,
            )
            self.add_span(span)

            self.otel_span_processor.queue_span_update(span, span_state="input")

            yield self
        finally:
            try:
                if span is not None:
                    span.duration = time.time() - start_time
                    self.otel_span_processor.queue_span_update(
                        span, span_state="completed"
                    )
            finally:
                # Always undo the bookkeeping done before the try, even if
                # queuing the completed span raised.
                self._span_depths.pop(span_id, None)
                self.reset_current_span(token)

    def async_evaluate(
        self,
        scorers: List[Union[APIScorerConfig, BaseScorer]],
        example: Optional[Example] = None,
        input: Optional[str] = None,
        actual_output: Optional[Union[str, List[str]]] = None,
        expected_output: Optional[Union[str, List[str]]] = None,
        context: Optional[List[str]] = None,
        retrieval_context: Optional[List[str]] = None,
        tools_called: Optional[List[str]] = None,
        expected_tools: Optional[List[str]] = None,
        additional_metadata: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        span_id: Optional[str] = None,
    ):
        """Attach an evaluation run for ``example`` and ``scorers`` to a span of this trace.

        Does nothing when evaluations are disabled or ``scorers`` is empty
        (the latter logs a warning). When ``example`` is None, one is built
        from the individual fields. The run targets ``span_id``, or the
        current span when ``span_id`` is None. It is always recorded on the
        trace, and is queued on the span processor only when the target span
        belongs to this trace.

        Raises:
            ValueError: if neither ``example`` nor any individual field is given.
        """
        if not self.enable_evaluations:
            return

        start_time = time.time()

        if not scorers:
            judgeval_logger.warning("No valid scorers available for evaluation")
            return

        if example is None:
            if any(
                param is not None
                for param in [
                    input,
                    actual_output,
                    expected_output,
                    context,
                    retrieval_context,
                    tools_called,
                    expected_tools,
                    additional_metadata,
                ]
            ):
                example = Example(
                    input=input,
                    actual_output=actual_output,
                    expected_output=expected_output,
                    context=context,
                    retrieval_context=retrieval_context,
                    tools_called=tools_called,
                    expected_tools=expected_tools,
                    additional_metadata=additional_metadata,
                )
            else:
                raise ValueError(
                    "Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
                )

        span_id_to_use = span_id if span_id is not None else self.get_current_span()

        eval_run = EvaluationRun(
            organization_id=self.tracer.organization_id,
            project_name=self.project_name,
            eval_name=f"{self.name.capitalize()}-"
            f"{span_id_to_use}-"
            f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
            examples=[example],
            scorers=scorers,
            model=model,
            judgment_api_key=self.tracer.api_key,
            trace_span_id=span_id_to_use,
        )

        self.add_eval_run(eval_run, start_time)

        if span_id_to_use:
            current_span = self.span_id_to_span.get(span_id_to_use)
            if current_span:
                self.otel_span_processor.queue_evaluation_run(
                    eval_run, span_id=span_id_to_use, span_data=current_span
                )

    def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
        """Record ``eval_run`` on this trace and flag its target span, if known.

        ``start_time`` is accepted for interface compatibility and is unused.
        A ``trace_span_id`` that is not one of this trace's spans no longer
        raises KeyError; the run is still recorded.
        """
        current_span_id = eval_run.trace_span_id
        if current_span_id:
            span = self.span_id_to_span.get(current_span_id)
            if span is not None:
                span.has_evaluation = True
        self.evaluation_runs.append(eval_run)

    def record_input(self, inputs: dict):
        """Store ``inputs`` (minus any ``self`` entry) on the current span and queue it.

        The caller's dict is copied rather than mutated. Queuing failures are
        logged, not raised.
        """
        span = self._get_current_span_obj()
        if span is None:
            return
        span.inputs = {key: value for key, value in inputs.items() if key != "self"}

        try:
            self.otel_span_processor.queue_span_update(span, span_state="input")
        except Exception as e:
            judgeval_logger.warning(f"Failed to queue span with input data: {e}")

    def record_agent_name(self, agent_name: str):
        """Set ``agent_name`` on the current span and queue the update."""
        span = self._get_current_span_obj()
        if span is None:
            return
        span.agent_name = agent_name

        self.otel_span_processor.queue_span_update(span, span_state="agent_name")

    def record_state_before(self, state: dict):
        """Records the agent's state before a tool execution on the current span.

        Args:
            state: A dictionary representing the agent's state.
        """
        span = self._get_current_span_obj()
        if span is None:
            return
        span.state_before = state

        self.otel_span_processor.queue_span_update(span, span_state="state_before")

    def record_state_after(self, state: dict):
        """Records the agent's state after a tool execution on the current span.

        Args:
            state: A dictionary representing the agent's state.
        """
        span = self._get_current_span_obj()
        if span is None:
            return
        span.state_after = state

        self.otel_span_processor.queue_span_update(span, span_state="state_after")

    async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
        """Await ``coroutine`` and store its result (or an error string) in ``span.<field>``.

        Output updates are queued on the span processor. Exceptions are
        re-raised after being recorded.
        """
        try:
            result = await coroutine
            setattr(span, field, result)

            if field == "output":
                self.otel_span_processor.queue_span_update(span, span_state="output")

            return result
        except Exception as e:
            setattr(span, field, f"Error: {str(e)}")

            if field == "output":
                self.otel_span_processor.queue_span_update(span, span_state="output")

            raise

    def record_output(self, output: Any):
        """Store ``output`` on the current span and return that span (None if no span).

        A coroutine output is recorded as "<pending>" and resolved by a
        background task, which needs a running event loop. A strong reference
        to the task is kept until it finishes.
        """
        span = self._get_current_span_obj()
        if span is None:
            return None

        if inspect.iscoroutine(output):
            span.output = "<pending>"
            task = asyncio.create_task(self._update_coroutine(span, output, "output"))
            self._pending_output_tasks.add(task)
            task.add_done_callback(self._pending_output_tasks.discard)
        else:
            span.output = output
            self.otel_span_processor.queue_span_update(span, span_state="output")

        return span

    def record_usage(self, usage: TraceUsage):
        """Store ``usage`` on the current span, queue it, and return the span (or None)."""
        span = self._get_current_span_obj()
        if span is None:
            return None
        span.usage = usage

        self.otel_span_processor.queue_span_update(span, span_state="usage")

        return span

    def record_error(self, error: Dict[str, Any]):
        """Store ``error`` on the current span, queue it, and return the span (or None)."""
        span = self._get_current_span_obj()
        if span is None:
            return None
        span.error = error

        self.otel_span_processor.queue_span_update(span, span_state="error")

        return span

    def add_span(self, span: TraceSpan):
        """Add a trace span to this trace context"""
        self.trace_spans.append(span)
        self.span_id_to_span[span.span_id] = span
        return self

    def print(self):
        """Print the complete trace with proper visual structure"""
        for span in self.trace_spans:
            span.print_span()

    def get_duration(self) -> float:
        """
        Get the total duration of this trace, measured from the first save (0.0 before it).
        """
        if self.start_time is None:
            return 0.0
        return time.time() - self.start_time

    def save(self, final_save: bool = False) -> Tuple[str, dict]:
        """
        Upsert the current state of this trace through ``TraceManagerClient``.

        On a final save, pending span updates are flushed from the span
        processor first; flush failures are logged, not raised. Any usage-limit
        handling would have to happen inside ``upsert_trace`` (not visible
        here). The trace's start time is fixed by the first save, and
        ``update_id`` is incremented after every save.

        Args:
            final_save: Whether this is the final save. Also passed to the
                upsert, which requests a link (``show_link``) only on
                non-final saves.

        Returns a tuple of (trace_id, server_response), where server_response is
        whatever ``upsert_trace`` returned (presumably the UI URL and other
        metadata).
        """
        if final_save:
            try:
                self.otel_span_processor.flush_pending_spans()
            except Exception as e:
                judgeval_logger.warning(
                    f"Error flushing spans for trace {self.trace_id}: {e}"
                )

        total_duration = self.get_duration()

        trace_data = {
            "trace_id": self.trace_id,
            "name": self.name,
            "project_name": self.project_name,
            "created_at": datetime.fromtimestamp(
                self.start_time or time.time(), timezone.utc
            ).isoformat(),
            "duration": total_duration,
            "trace_spans": [span.model_dump() for span in self.trace_spans],
            "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
            "offline_mode": self.tracer.offline_mode,
            "parent_trace_id": self.parent_trace_id,
            "parent_name": self.parent_name,
            "customer_id": self.customer_id,
            "tags": self.tags,
            "metadata": self.metadata,
            "update_id": self.update_id,
        }

        server_response = self.trace_manager_client.upsert_trace(
            trace_data,
            offline_mode=self.tracer.offline_mode,
            show_link=not final_save,
            final_save=final_save,
        )

        if self.start_time is None:
            self.start_time = time.time()

        self.update_id += 1

        return self.trace_id, server_response

    def delete(self):
        """Delete this trace via the trace manager client and return its result."""
        return self.trace_manager_client.delete_trace(self.trace_id)

    def update_metadata(self, metadata: dict):
        """
        Set metadata for this trace.

        Args:
            metadata: Metadata as a dictionary

        Supported keys:
        - customer_id: ID of the customer using this trace (stored as str; None clears it)
        - tags: List of tags for this trace (each a str, set or tuple)
        - has_notification: Whether this trace has a notification (must be bool)
        - name: Name of the trace
        Any other key is stored as-is in ``self.metadata``.

        Raises:
            ValueError: if ``tags`` or ``has_notification`` has the wrong type.
        """
        for k, v in metadata.items():
            if k == "customer_id":
                if v is not None:
                    self.customer_id = str(v)
                else:
                    self.customer_id = None
            elif k == "tags":
                if isinstance(v, list):
                    for item in v:
                        if not isinstance(item, (str, set, tuple)):
                            raise ValueError(
                                f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
                            )
                    self.tags = v
                else:
                    raise ValueError(
                        f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
                    )
            elif k == "has_notification":
                if not isinstance(v, bool):
                    raise ValueError(
                        f"has_notification must be a boolean, got {type(v)}"
                    )
                self.has_notification = v
            elif k == "name":
                self.name = v
            else:
                self.metadata[k] = v

    def set_customer_id(self, customer_id: str):
        """
        Set the customer ID for this trace.

        Args:
            customer_id: The customer ID to set
        """
        self.update_metadata({"customer_id": customer_id})

    def set_tags(self, tags: List[Union[str, set, tuple]]):
        """
        Set the tags for this trace.

        Args:
            tags: List of tags to set
        """
        self.update_metadata({"tags": tags})

    def set_reward_score(self, reward_score: Union[float, Dict[str, float]]):
        """
        Set the reward score for this trace to be used for RL or SFT.

        Stored under ``self.metadata["reward_score"]``.

        Args:
            reward_score: The reward score to set
        """
        self.update_metadata({"reward_score": reward_score})
|
513
|
+
|
514
|
+
|
515
|
+
def _capture_exception_for_trace(
    current_trace: Optional[TraceClient], exc_info: ExcInfo
):
    """Record ``exc_info`` as an error on the current span of ``current_trace``.

    The recorded payload always carries ``type``, ``message`` and ``traceback``
    (a list of formatted traceback strings). If the exception exposes a truthy
    ``.response.status_code``, an ``http`` entry with that status code and
    ``.request.url`` (or "Unknown URL") is appended. Does nothing when
    ``current_trace`` is falsy.
    """
    if not current_trace:
        return

    err_type, err_value, err_tb = exc_info

    error_payload: Dict[str, Any] = {}
    error_payload["type"] = err_type.__name__ if err_type else "UnknownExceptionType"
    error_payload["message"] = str(err_value) if err_value else "No exception message"
    error_payload["traceback"] = traceback.format_tb(err_tb) if err_tb else []

    # Exceptions that deserve extra data are recognised by duck typing rather
    # than by importing client libraries. If a library-specific check is ever
    # needed, look the module up in sys.modules instead of importing it, so
    # clients that don't use that library are unaffected.
    #
    # HTTP-like errors: most clients (requests, httpx, urllib) expose
    # ``err.request.url`` and ``err.response.status_code``. The lookups are
    # guarded because such a property may raise something other than
    # AttributeError (which getattr's default would not absorb).
    try:
        request_obj = getattr(err_value, "request", None)
        url = getattr(request_obj, "url", None)
        response_obj = getattr(err_value, "response", None)
        status_code = getattr(response_obj, "status_code", None)
        if status_code:
            error_payload["http"] = {
                "url": url or "Unknown URL",
                "status_code": status_code,
            }
    except Exception:
        pass

    current_trace.record_error(error_payload)
|
551
|
+
|
552
|
+
|
553
|
+
class _DeepTracer:
|
554
|
+
_instance: Optional["_DeepTracer"] = None
|
555
|
+
_lock: threading.Lock = threading.Lock()
|
556
|
+
_refcount: int = 0
|
557
|
+
_span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
|
558
|
+
"_deep_profiler_span_stack", default=[]
|
559
|
+
)
|
560
|
+
_skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
|
561
|
+
"_deep_profiler_skip_stack", default=[]
|
562
|
+
)
|
563
|
+
_original_sys_trace: Optional[Callable] = None
|
564
|
+
_original_threading_trace: Optional[Callable] = None
|
565
|
+
|
566
|
+
def __init__(self, tracer: "Tracer"):
|
567
|
+
self._tracer = tracer
|
568
|
+
|
569
|
+
def _get_qual_name(self, frame) -> str:
|
570
|
+
func_name = frame.f_code.co_name
|
571
|
+
module_name = frame.f_globals.get("__name__", "unknown_module")
|
572
|
+
|
573
|
+
try:
|
574
|
+
func = frame.f_globals.get(func_name)
|
575
|
+
if func is None:
|
576
|
+
return f"{module_name}.{func_name}"
|
577
|
+
if hasattr(func, "__qualname__"):
|
578
|
+
return f"{module_name}.{func.__qualname__}"
|
579
|
+
return f"{module_name}.{func_name}"
|
580
|
+
except Exception:
|
581
|
+
return f"{module_name}.{func_name}"
|
582
|
+
|
583
|
+
def __new__(cls, tracer: "Tracer"):
|
584
|
+
with cls._lock:
|
585
|
+
if cls._instance is None:
|
586
|
+
cls._instance = super().__new__(cls)
|
587
|
+
return cls._instance
|
588
|
+
|
589
|
+
def _should_trace(self, frame):
|
590
|
+
# Skip stack is maintained by the tracer as an optimization to skip earlier
|
591
|
+
# frames in the call stack that we've already determined should be skipped
|
592
|
+
skip_stack = self._skip_stack.get()
|
593
|
+
if len(skip_stack) > 0:
|
594
|
+
return False
|
595
|
+
|
596
|
+
func_name = frame.f_code.co_name
|
597
|
+
module_name = frame.f_globals.get("__name__", None)
|
598
|
+
func = frame.f_globals.get(func_name)
|
599
|
+
if func and (
|
600
|
+
hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
|
601
|
+
):
|
602
|
+
return False
|
603
|
+
|
604
|
+
if (
|
605
|
+
not module_name
|
606
|
+
or func_name.startswith("<") # ex: <listcomp>
|
607
|
+
or func_name.startswith("__")
|
608
|
+
and func_name != "__call__" # dunders
|
609
|
+
or not self._is_user_code(frame.f_code.co_filename)
|
610
|
+
):
|
611
|
+
return False
|
612
|
+
|
613
|
+
return True
|
614
|
+
|
615
|
+
@functools.cache
|
616
|
+
def _is_user_code(self, filename: str):
|
617
|
+
return (
|
618
|
+
bool(filename)
|
619
|
+
and not filename.startswith("<")
|
620
|
+
and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
|
621
|
+
)
|
622
|
+
|
623
|
+
def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
|
624
|
+
"""Cooperative trace function for sys.settrace that chains with existing tracers."""
|
625
|
+
# First, call the original sys trace function if it exists
|
626
|
+
original_result = None
|
627
|
+
if self._original_sys_trace:
|
628
|
+
try:
|
629
|
+
original_result = self._original_sys_trace(frame, event, arg)
|
630
|
+
except Exception:
|
631
|
+
pass
|
632
|
+
|
633
|
+
our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
|
634
|
+
|
635
|
+
if original_result is None and self._original_sys_trace:
|
636
|
+
return None
|
637
|
+
|
638
|
+
return our_result or original_result
|
639
|
+
|
640
|
+
def _cooperative_threading_trace(
|
641
|
+
self, frame: types.FrameType, event: str, arg: Any
|
642
|
+
):
|
643
|
+
"""Cooperative trace function for threading.settrace that chains with existing tracers."""
|
644
|
+
original_result = None
|
645
|
+
if self._original_threading_trace:
|
646
|
+
try:
|
647
|
+
original_result = self._original_threading_trace(frame, event, arg)
|
648
|
+
except Exception:
|
649
|
+
pass
|
650
|
+
|
651
|
+
our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
|
652
|
+
|
653
|
+
if original_result is None and self._original_threading_trace:
|
654
|
+
return None
|
655
|
+
|
656
|
+
return our_result or original_result
|
657
|
+
|
658
|
+
def _trace(
|
659
|
+
self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
|
660
|
+
):
|
661
|
+
frame.f_trace_lines = False
|
662
|
+
frame.f_trace_opcodes = False
|
663
|
+
|
664
|
+
if not self._should_trace(frame):
|
665
|
+
return
|
666
|
+
|
667
|
+
if event not in ("call", "return", "exception"):
|
668
|
+
return
|
669
|
+
|
670
|
+
current_trace = self._tracer.get_current_trace()
|
671
|
+
if not current_trace:
|
672
|
+
return
|
673
|
+
|
674
|
+
parent_span_id = self._tracer.get_current_span()
|
675
|
+
if not parent_span_id:
|
676
|
+
return
|
677
|
+
|
678
|
+
qual_name = self._get_qual_name(frame)
|
679
|
+
instance_name = None
|
680
|
+
if "self" in frame.f_locals:
|
681
|
+
instance = frame.f_locals["self"]
|
682
|
+
class_name = instance.__class__.__name__
|
683
|
+
class_identifiers = getattr(self._tracer, "class_identifiers", {})
|
684
|
+
instance_name = get_instance_prefixed_name(
|
685
|
+
instance, class_name, class_identifiers
|
686
|
+
)
|
687
|
+
skip_stack = self._skip_stack.get()
|
688
|
+
|
689
|
+
if event == "call":
|
690
|
+
# If we have entries in the skip stack and the current qual_name matches the top entry,
|
691
|
+
# push it again to track nesting depth and skip
|
692
|
+
# As an optimization, we only care about duplicate qual_names.
|
693
|
+
if skip_stack:
|
694
|
+
if qual_name == skip_stack[-1]:
|
695
|
+
skip_stack.append(qual_name)
|
696
|
+
self._skip_stack.set(skip_stack)
|
697
|
+
return
|
698
|
+
|
699
|
+
should_trace = self._should_trace(frame)
|
700
|
+
|
701
|
+
if not should_trace:
|
702
|
+
if not skip_stack:
|
703
|
+
self._skip_stack.set([qual_name])
|
704
|
+
return
|
705
|
+
elif event == "return":
|
706
|
+
# If we have entries in skip stack and current qual_name matches the top entry,
|
707
|
+
# pop it to track exiting from the skipped section
|
708
|
+
if skip_stack and qual_name == skip_stack[-1]:
|
709
|
+
skip_stack.pop()
|
710
|
+
self._skip_stack.set(skip_stack)
|
711
|
+
return
|
712
|
+
|
713
|
+
if skip_stack:
|
714
|
+
return
|
715
|
+
|
716
|
+
span_stack = self._span_stack.get()
|
717
|
+
if event == "call":
|
718
|
+
if not self._should_trace(frame):
|
719
|
+
return
|
720
|
+
|
721
|
+
span_id = str(uuid.uuid4())
|
722
|
+
|
723
|
+
parent_depth = current_trace._span_depths.get(parent_span_id, 0)
|
724
|
+
depth = parent_depth + 1
|
725
|
+
|
726
|
+
current_trace._span_depths[span_id] = depth
|
727
|
+
|
728
|
+
start_time = time.time()
|
729
|
+
|
730
|
+
span_stack.append(
|
731
|
+
{
|
732
|
+
"span_id": span_id,
|
733
|
+
"parent_span_id": parent_span_id,
|
734
|
+
"function": qual_name,
|
735
|
+
"start_time": start_time,
|
736
|
+
}
|
737
|
+
)
|
738
|
+
self._span_stack.set(span_stack)
|
739
|
+
|
740
|
+
token = self._tracer.set_current_span(span_id)
|
741
|
+
frame.f_locals["_judgment_span_token"] = token
|
742
|
+
|
743
|
+
span = TraceSpan(
|
744
|
+
span_id=span_id,
|
745
|
+
trace_id=current_trace.trace_id,
|
746
|
+
depth=depth,
|
747
|
+
message=qual_name,
|
748
|
+
created_at=start_time,
|
749
|
+
span_type="span",
|
750
|
+
parent_span_id=parent_span_id,
|
751
|
+
function=qual_name,
|
752
|
+
agent_name=instance_name,
|
753
|
+
)
|
754
|
+
current_trace.add_span(span)
|
755
|
+
|
756
|
+
inputs = {}
|
757
|
+
try:
|
758
|
+
args_info = inspect.getargvalues(frame)
|
759
|
+
for arg in args_info.args:
|
760
|
+
try:
|
761
|
+
inputs[arg] = args_info.locals.get(arg)
|
762
|
+
except Exception:
|
763
|
+
inputs[arg] = "<<Unserializable>>"
|
764
|
+
current_trace.record_input(inputs)
|
765
|
+
except Exception as e:
|
766
|
+
current_trace.record_input({"error": str(e)})
|
767
|
+
|
768
|
+
elif event == "return":
|
769
|
+
if not span_stack:
|
770
|
+
return
|
771
|
+
|
772
|
+
current_id = self._tracer.get_current_span()
|
773
|
+
|
774
|
+
span_data = None
|
775
|
+
for i, entry in enumerate(reversed(span_stack)):
|
776
|
+
if entry["span_id"] == current_id:
|
777
|
+
span_data = span_stack.pop(-(i + 1))
|
778
|
+
self._span_stack.set(span_stack)
|
779
|
+
break
|
780
|
+
|
781
|
+
if not span_data:
|
782
|
+
return
|
783
|
+
|
784
|
+
start_time = span_data["start_time"]
|
785
|
+
duration = time.time() - start_time
|
786
|
+
|
787
|
+
current_trace.span_id_to_span[span_data["span_id"]].duration = duration
|
788
|
+
|
789
|
+
if arg is not None:
|
790
|
+
# exception handling will take priority.
|
791
|
+
current_trace.record_output(arg)
|
792
|
+
|
793
|
+
if span_data["span_id"] in current_trace._span_depths:
|
794
|
+
del current_trace._span_depths[span_data["span_id"]]
|
795
|
+
|
796
|
+
if span_stack:
|
797
|
+
self._tracer.set_current_span(span_stack[-1]["span_id"])
|
798
|
+
else:
|
799
|
+
self._tracer.set_current_span(span_data["parent_span_id"])
|
800
|
+
|
801
|
+
if "_judgment_span_token" in frame.f_locals:
|
802
|
+
self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
|
803
|
+
|
804
|
+
elif event == "exception":
|
805
|
+
exc_type = arg[0]
|
806
|
+
if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
|
807
|
+
return
|
808
|
+
_capture_exception_for_trace(current_trace, arg)
|
809
|
+
|
810
|
+
return continuation_func
|
811
|
+
|
812
|
+
def __enter__(self):
    """Start deep tracing; only the outermost entry installs the hooks.

    Every entry increments ``self._refcount`` under ``self._lock``. On the
    first entry the trace functions already installed (for this thread via
    ``sys.gettrace`` and for new threads via ``threading.gettrace``) are
    remembered so ``__exit__`` can restore them, the skip/span stacks are
    reset, and the cooperative trace functions are installed.

    Returns:
        self
    """
    with self._lock:
        self._refcount += 1
        if self._refcount != 1:
            # Nested entry: hooks are already installed.
            return self

        # Remember whatever trace functions were active before ours.
        self._original_sys_trace = sys.gettrace()
        self._original_threading_trace = threading.gettrace()

        # Start this tracing session with empty bookkeeping stacks.
        self._skip_stack.set([])
        self._span_stack.set([])

        # sys.settrace covers the current thread; threading.settrace covers
        # threads started afterwards.
        sys.settrace(self._cooperative_sys_trace)
        threading.settrace(self._cooperative_threading_trace)
    return self
|
826
|
+
|
827
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Stop deep tracing once the outermost context exits.

    Decrements ``self._refcount`` under ``self._lock``; when it reaches zero
    the trace functions saved by ``__enter__`` are reinstalled and the saved
    references are cleared. Always returns None, so exceptions raised inside
    the ``with`` block propagate unchanged.
    """
    with self._lock:
        self._refcount -= 1
        if self._refcount:
            # Still inside an outer `with`: keep our hooks installed.
            return

        # Hand tracing back to the functions that were active before us
        # (rather than unconditionally disabling tracing with None).
        sys.settrace(self._original_sys_trace)
        threading.settrace(self._original_threading_trace)

        # The saved functions are no longer needed.
        self._original_sys_trace = self._original_threading_trace = None
|
838
|
+
|
839
|
+
|
840
|
+
class Tracer:
    """Create and manage Judgment traces.

    A ``Tracer`` validates credentials, owns the span processor, tracks the
    current trace/span, and provides the ``trace`` context manager plus the
    ``observe``, ``observe_tools`` and ``identify`` decorators.
    """

    # Tracer.current_trace class variable is currently used in wrap()
    # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
    # Should be fine to do so as long as we keep Tracer as a singleton
    current_trace: Optional[TraceClient] = None
    # Written by set_current_span()/reset_current_span(); declared here so the
    # class attribute exists even before any span has been set.
    current_span_id: Optional[str] = None

    trace_across_async_contexts: bool = (
        False  # By default, we don't trace across async contexts
    )

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization_id: Optional[str] = None,
        project_name: Optional[str] = None,
        deep_tracing: bool = False,  # Deep tracing is disabled by default
        enable_monitoring: Optional[bool] = None,
        enable_evaluations: Optional[bool] = None,
        # S3 configuration
        use_s3: bool = False,
        s3_bucket_name: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        trace_across_async_contexts: bool = False,  # By default, we don't trace across async contexts
        span_batch_size: int = 50,
        span_flush_interval: float = 1.0,
        span_max_queue_size: int = 2048,
        span_export_timeout: int = 30000,
    ):
        """Configure the tracer.

        Args:
            api_key: Judgment API key. Defaults to the ``JUDGMENT_API_KEY``
                environment variable, read when the tracer is constructed.
            organization_id: Organization ID. Defaults to ``JUDGMENT_ORG_ID``,
                read when the tracer is constructed.
            project_name: Project for new traces; ``"default_project"`` if unset.
            deep_tracing: Run observed functions under ``_DeepTracer``.
            enable_monitoring: Defaults to the ``JUDGMENT_MONITORING``
                environment variable (True unless it is set to something other
                than ``"true"``, case-insensitive).
            enable_evaluations: Defaults to ``JUDGMENT_EVALUATIONS``, same rule.
            use_s3: Store traces in S3; requires ``s3_bucket_name``.
            s3_bucket_name, s3_aws_access_key_id, s3_aws_secret_access_key,
            s3_region_name: Forwarded to ``S3Storage``.
            trace_across_async_contexts: Also track the current trace/span on
                the tracer itself (and the ``Tracer`` class), not only in
                context variables.
            span_batch_size, span_flush_interval, span_max_queue_size,
            span_export_timeout: Forwarded to ``JudgmentSpanProcessor``.

        Configuration and validation errors are logged rather than raised; the
        tracer then stays usable but with monitoring and evaluations disabled.
        """
        # Read environment defaults now, not at import time, so variables set
        # after `import judgeval` (e.g. by a .env loader) are honoured.
        if api_key is None:
            api_key = os.getenv("JUDGMENT_API_KEY")
        if organization_id is None:
            organization_id = os.getenv("JUDGMENT_ORG_ID")
        if enable_monitoring is None:
            enable_monitoring = self._env_flag("JUDGMENT_MONITORING")
        if enable_evaluations is None:
            enable_evaluations = self._env_flag("JUDGMENT_EVALUATIONS")

        # Give every attribute a safe value before validating, so a failure
        # below leaves a consistent (disabled) tracer instead of a half-built
        # object whose methods raise AttributeError.
        self.api_key: Optional[str] = api_key
        self.project_name: str = project_name or "default_project"
        self.organization_id: Optional[str] = organization_id
        # observe() appends one snapshot dict per completed root trace.
        self.traces: List[Dict[str, Any]] = []
        self.enable_monitoring: bool = enable_monitoring
        self.enable_evaluations: bool = enable_evaluations
        # class name -> configuration registered by identify()
        self.class_identifiers: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): set_current_* add an entry on every call and
        # reset_current_* only read them, so these maps grow over the tracer's
        # lifetime.
        self.span_id_to_previous_span_id: Dict[str, Optional[str]] = {}
        self.trace_id_to_previous_trace: Dict[str, Optional[TraceClient]] = {}
        self.current_span_id: Optional[str] = None
        self.current_trace: Optional[TraceClient] = None
        self.trace_across_async_contexts: bool = trace_across_async_contexts
        self.use_s3: bool = False
        self.s3_storage = None
        self.offline_mode = False  # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
        self.deep_tracing: bool = deep_tracing
        self.span_batch_size = span_batch_size
        self.span_flush_interval = span_flush_interval
        self.span_max_queue_size = span_max_queue_size
        self.span_export_timeout = span_export_timeout
        self.otel_span_processor: SpanProcessorBase = SpanProcessorBase()

        try:
            if not api_key:
                raise ValueError(
                    "api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
                )

            if not organization_id:
                raise ValueError(
                    "organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
                )

            try:
                result, response = validate_api_key(api_key)
            except Exception as e:
                # Verification itself failed: keep the tracer but stop exporting.
                judgeval_logger.error(
                    f"Issue with verifying API key, disabling monitoring: {e}"
                )
                enable_monitoring = False
                result = True

            if not result:
                raise ValueError(f"Issue with passed in Judgment API key: {response}")

            if use_s3 and not s3_bucket_name:
                raise ValueError("S3 bucket name must be provided when use_s3 is True")

            self.enable_monitoring = enable_monitoring
            Tracer.trace_across_async_contexts = trace_across_async_contexts

            # Initialize S3 storage if enabled
            if use_s3:
                from judgeval.common.storage.s3_storage import S3Storage

                self.use_s3 = True
                try:
                    self.s3_storage = S3Storage(
                        bucket_name=s3_bucket_name,
                        aws_access_key_id=s3_aws_access_key_id,
                        aws_secret_access_key=s3_aws_secret_access_key,
                        region_name=s3_region_name,
                    )
                except Exception as e:
                    judgeval_logger.error(
                        f"Issue with initializing S3 storage, disabling S3: {e}"
                    )
                    self.use_s3 = False

            if enable_monitoring:
                self.otel_span_processor = JudgmentSpanProcessor(
                    judgment_api_key=api_key,
                    organization_id=organization_id,
                    batch_size=span_batch_size,
                    flush_interval=span_flush_interval,
                    max_queue_size=span_max_queue_size,
                    export_timeout=span_export_timeout,
                )

            atexit.register(self._cleanup_on_exit)
        except Exception as e:
            judgeval_logger.error(
                f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
            )
            self.enable_monitoring = False
            self.enable_evaluations = False

    @staticmethod
    def _env_flag(var_name: str) -> bool:
        """Read a boolean environment flag: True if unset or ``"true"`` (any case)."""
        return os.getenv(var_name, "true").lower() == "true"

    def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
        """Make ``span_id`` the current span.

        Remembers the previously current span (used by reset_current_span),
        updates the instance and class attributes and the ``current_span_var``
        context variable.

        Returns:
            The context-variable token, or None if setting it failed.
        """
        self.span_id_to_previous_span_id[span_id] = self.current_span_id
        self.current_span_id = span_id
        Tracer.current_span_id = span_id
        try:
            token = current_span_var.set(span_id)
        except Exception:
            token = None
        return token

    def get_current_span(self) -> Optional[str]:
        """Return the current span ID.

        Reads the ``current_span_var`` context variable; when
        ``trace_across_async_contexts`` is enabled, the span stored on this
        tracer takes precedence over the context variable.
        """
        try:
            current_span_var_val = current_span_var.get()
        except Exception:
            current_span_var_val = None
        return (
            (self.current_span_id or current_span_var_val)
            if self.trace_across_async_contexts
            else current_span_var_val
        )

    def reset_current_span(
        self,
        token: Optional[contextvars.Token[str | None]] = None,
        span_id: Optional[str] = None,
    ):
        """Undo a set_current_span call.

        Args:
            token: Token returned by set_current_span; resets the context
                variable when given (errors are ignored).
            span_id: Span being left. Defaults to the tracer's current span;
                the tracer's current span becomes the one recorded as its
                predecessor.
        """
        try:
            if token:
                current_span_var.reset(token)
        except Exception:
            pass
        if not span_id:
            span_id = self.current_span_id
        if span_id:
            self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
            Tracer.current_span_id = self.current_span_id

    def set_current_trace(
        self, trace: TraceClient
    ) -> Optional[contextvars.Token[TraceClient | None]]:
        """
        Set the current trace context in contextvars (and on the tracer/class).

        Returns:
            The context-variable token, or None if setting it failed.
        """
        self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
        self.current_trace = trace
        Tracer.current_trace = trace
        try:
            token = current_trace_var.set(trace)
        except Exception:
            token = None
        return token

    def get_current_trace(self) -> Optional[TraceClient]:
        """
        Get the current trace context.

        Reads the ``current_trace_var`` context variable. When
        ``trace_across_async_contexts`` is enabled, the trace stored on this
        tracer takes precedence and the context variable is only a fallback.
        """
        try:
            current_trace_var_val = current_trace_var.get()
        except Exception:
            current_trace_var_val = None
        return (
            (self.current_trace or current_trace_var_val)
            if self.trace_across_async_contexts
            else current_trace_var_val
        )

    def reset_current_trace(
        self,
        token: Optional[contextvars.Token[TraceClient | None]] = None,
        trace_id: Optional[str] = None,
    ):
        """Undo a set_current_trace call.

        Args:
            token: Token returned by set_current_trace; resets the context
                variable when given (errors are ignored).
            trace_id: Trace being left. Defaults to the tracer's current trace;
                the tracer's current trace becomes the one recorded as its
                predecessor.
        """
        try:
            if token:
                current_trace_var.reset(token)
        except Exception:
            pass
        if not trace_id and self.current_trace:
            trace_id = self.current_trace.trace_id
        if trace_id:
            self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
            Tracer.current_trace = self.current_trace

    @contextmanager
    def trace(
        self, name: str, project_name: str | None = None
    ) -> Generator[TraceClient, None, None]:
        """Start a new trace context using a context manager.

        The new trace records the currently active trace (if any) as its
        parent, is the current trace for the duration of the ``with`` block,
        and runs inside a root span named after it. The previous trace is
        restored on exit, including when entering the root span fails.
        """
        parent_trace = self.get_current_trace()

        new_trace = TraceClient(
            self,
            str(uuid.uuid4()),
            name,
            project_name=project_name if project_name is not None else self.project_name,
            enable_monitoring=self.enable_monitoring,
            enable_evaluations=self.enable_evaluations,
            parent_trace_id=parent_trace.trace_id if parent_trace else None,
            parent_name=parent_trace.name if parent_trace else None,
        )

        token = self.set_current_trace(new_trace)
        restored = False
        try:
            with new_trace.span(name or "unnamed_trace"):
                try:
                    yield new_trace
                finally:
                    # Reset while the root span is still open.
                    restored = True
                    self.reset_current_trace(token, trace_id=new_trace.trace_id)
        finally:
            # Reached without a reset only if entering the root span raised.
            if not restored:
                self.reset_current_trace(token, trace_id=new_trace.trace_id)

    def identify(
        self,
        identifier: str,
        track_state: bool = False,
        track_attributes: Optional[List[str]] = None,
        field_mappings: Optional[Dict[str, str]] = None,
    ):
        """
        Class decorator that associates a class with a custom identifier and enables state tracking.

        The configuration is stored in ``self.class_identifiers`` keyed by the
        class's ``__name__``. Observed methods of the class use it to name the
        agent and, with ``track_state``, to capture instance state.

        Args:
            identifier: The identifier to associate with the decorated class.
                        This will be used as the instance name in traces.
            track_state: Whether to capture the state (attributes) of instances
                        before and after observed method execution. Defaults to False.
            track_attributes: Optional list of specific attribute names to track.
                            If None or empty, all non-private attributes (not
                            starting with '_') are tracked when track_state=True.
            field_mappings: Optional dictionary mapping attribute names to display
                          names. This tracer does not rename attributes itself; it
                          records the mapping in the captured state under the
                          ``"field_mappings"`` key.

        Example:
            @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
            class User:
                # Class implementation
        """

        def decorator(cls):
            self.class_identifiers[cls.__name__] = {
                "identifier": identifier,
                "track_state": track_state,
                "track_attributes": track_attributes,
                "field_mappings": field_mappings or {},
            }
            return cls

        return decorator

    def _capture_instance_state(
        self, instance: Any, class_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Capture the state of an instance based on class configuration.

        Args:
            instance: The instance to capture the state of.
            class_config: Configuration dictionary for state capture,
                          expected to contain 'track_attributes' and 'field_mappings'.

        Returns:
            The listed attributes (missing ones as None), or every public
            attribute in the instance ``__dict__``; plus the field mappings
            under ``"field_mappings"`` when configured.
        """
        track_attributes = class_config.get("track_attributes")
        field_mappings = class_config.get("field_mappings")

        if track_attributes:
            state = {attr: getattr(instance, attr, None) for attr in track_attributes}
        else:
            # Instances using __slots__ have no __dict__; capture nothing rather
            # than raising out of the user's observed call.
            instance_dict = getattr(instance, "__dict__", None) or {}
            state = {k: v for k, v in instance_dict.items() if not k.startswith("_")}

        if field_mappings:
            state["field_mappings"] = field_mappings

        return state

    def _get_instance_state_if_tracked(self, args) -> Optional[Dict[str, Any]]:
        """
        Extract instance state if the instance should be tracked.

        ``args[0]`` is treated as the instance; returns its captured state when
        its class was registered via identify(track_state=True), else None.
        """
        if not args:
            return None
        instance = args[0]
        config = self.class_identifiers.get(instance.__class__.__name__)
        if isinstance(config, dict) and config.get("track_state", False):
            return self._capture_instance_state(instance, config)
        return None

    def _conditionally_capture_and_record_state(
        self, trace_client_instance: TraceClient, args: tuple, is_before: bool
    ):
        """Captures instance state if tracked and records it via the trace_client."""
        state = self._get_instance_state_if_tracked(args)
        if state:
            if is_before:
                trace_client_instance.record_state_before(state)
            else:
                trace_client_instance.record_state_after(state)

    def _get_agent_name(self, args: tuple) -> Optional[str]:
        """Agent name derived from ``args[0]`` (the instance for methods), if any."""
        if args and hasattr(args[0], "__class__"):
            return get_instance_prefixed_name(
                args[0], args[0].__class__.__name__, self.class_identifiers
            )
        return None

    def _record_call_start(self, span, func, args, kwargs, agent_name) -> None:
        """Record inputs, agent name and pre-call state on ``span``."""
        span.record_input(combine_args_kwargs(func, args, kwargs))
        if agent_name:
            span.record_agent_name(agent_name)
        self._conditionally_capture_and_record_state(span, args, is_before=True)

    def _new_root_trace(self, span_name: str) -> TraceClient:
        """Create a root trace for an observed call made outside any trace."""
        return TraceClient(
            self,
            str(uuid.uuid4()),
            span_name,
            project_name=self.project_name,
            enable_monitoring=self.enable_monitoring,
            enable_evaluations=self.enable_evaluations,
        )

    def _build_trace_record(self, trace_client: TraceClient) -> Dict[str, Any]:
        """Snapshot of a finished trace in the shape stored in ``self.traces``."""
        return {
            "trace_id": trace_client.trace_id,
            "name": trace_client.name,
            "created_at": datetime.fromtimestamp(
                trace_client.start_time or time.time(),
                timezone.utc,
            ).isoformat(),
            "duration": trace_client.get_duration(),
            "trace_spans": [s.model_dump() for s in trace_client.trace_spans],
            "offline_mode": self.offline_mode,
            "parent_trace_id": trace_client.parent_trace_id,
            "parent_name": trace_client.parent_name,
        }

    def _finalize_root_trace(
        self,
        trace_client: TraceClient,
        trace_token: Optional[contextvars.Token[TraceClient | None]],
        warning_label: str,
    ) -> None:
        """Save a root trace created by observe() and restore the previous trace.

        Save/serialization errors are logged (prefixed by ``warning_label``)
        and never raised, so they cannot mask the observed call's own result
        or exception. The previous trace is restored in every case; otherwise
        a failed save would leave the finished trace installed as current.
        """
        try:
            trace_client.save(final_save=True)
            self.traces.append(self._build_trace_record(trace_client))
        except Exception as e:
            judgeval_logger.warning(f"{warning_label}: {e}")
        finally:
            self.reset_current_trace(trace_token, trace_id=trace_client.trace_id)

    def observe(
        self,
        func=None,
        *,
        name=None,
        span_type: SpanType = "span",
    ):
        """
        Decorator to trace function execution with detailed entry/exit information.

        Usable bare (``@tracer.observe``) or with arguments. Each call runs in
        a span of the current trace; when no trace is active, a new root trace
        is created for the call, saved when it finishes, and a snapshot of it
        appended to ``self.traces``.

        Args:
            func: The function to decorate
            name: Optional custom name for the span (defaults to function name)
            span_type: Type of span (default "span").

        Returns:
            The wrapped function, or ``func`` unchanged when monitoring is
            disabled or the function cannot be tagged (e.g. a bound method).
        """
        try:
            # If monitoring is disabled, return the function as is
            if not self.enable_monitoring:
                return func if func else lambda f: f

            # Called with arguments only: hand back the actual decorator.
            if func is None:
                return lambda f: self.observe(f, name=name, span_type=span_type)

            # Use provided name or fall back to function name
            span_name = name or func.__name__

            # Tag the function; observe_tools() uses _judgment_span_name to
            # skip methods that are already decorated.
            func._judgment_span_name = span_name
            func._judgment_span_type = span_type
        except Exception:
            return func

        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                agent_name = self._get_agent_name(args)
                current_trace = self.get_current_trace()

                # No active trace: this call becomes the root of a new one.
                is_root = not current_trace
                trace_token = None
                if is_root:
                    current_trace = self._new_root_trace(span_name)
                    trace_token = self.set_current_trace(current_trace)

                try:
                    with current_trace.span(span_name, span_type=span_type) as span:
                        self._record_call_start(span, func, args, kwargs, agent_name)
                        try:
                            if self.deep_tracing:
                                with _DeepTracer(self):
                                    result = await func(*args, **kwargs)
                            else:
                                result = await func(*args, **kwargs)
                        except Exception:
                            _capture_exception_for_trace(current_trace, sys.exc_info())
                            raise

                        self._conditionally_capture_and_record_state(
                            span, args, is_before=False
                        )
                        span.record_output(result)
                        return result
                finally:
                    if is_root:
                        self._finalize_root_trace(
                            current_trace, trace_token, "Issue with async_wrapper"
                        )

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            agent_name = self._get_agent_name(args)
            current_trace = self.get_current_trace()

            # No active trace: this call becomes the root of a new one.
            is_root = not current_trace
            trace_token = None
            if is_root:
                current_trace = self._new_root_trace(span_name)
                trace_token = self.set_current_trace(current_trace)

            try:
                with current_trace.span(span_name, span_type=span_type) as span:
                    self._record_call_start(span, func, args, kwargs, agent_name)
                    try:
                        if self.deep_tracing:
                            with _DeepTracer(self):
                                result = func(*args, **kwargs)
                        else:
                            result = func(*args, **kwargs)
                    except Exception:
                        _capture_exception_for_trace(current_trace, sys.exc_info())
                        raise

                    self._conditionally_capture_and_record_state(
                        span, args, is_before=False
                    )
                    span.record_output(result)
                    return result
            finally:
                if is_root:
                    self._finalize_root_trace(
                        current_trace, trace_token, "Issue with save"
                    )

        return wrapper

    def observe_tools(
        self,
        cls=None,
        *,
        exclude_methods: Optional[List[str]] = None,
        include_private: bool = False,
        warn_on_double_decoration: bool = True,
    ):
        """
        Automatically adds @observe(span_type="tool") to all methods in a class.

        Static and class methods keep their kind after decoration; nested
        classes and non-callable attributes are left alone.

        Args:
            cls: The class to decorate (automatically provided when used as decorator)
            exclude_methods: List of method names to skip decorating. Defaults to common magic methods
            include_private: Whether to decorate methods starting with underscore. Defaults to False
            warn_on_double_decoration: Whether to log when skipping already-decorated methods
                or when decoration fails. Defaults to True
        """
        if exclude_methods is None:
            exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]

        def decorate_class(cls):
            if not self.enable_monitoring:
                return cls

            for name in dir(cls):
                if name in exclude_methods or (name.startswith("_") and not include_private):
                    continue

                method = getattr(cls, name)
                # Nested classes are callable but are not tool methods; wrapping
                # one in a function would break isinstance() and subclassing.
                if not callable(method) or inspect.isclass(method):
                    continue

                # Unwrap staticmethod/classmethod and re-wrap after decorating.
                # Otherwise a decorated staticmethod turns into a plain function
                # (receiving the instance as its first argument), and a
                # classmethod is seen as a bound method observe() cannot tag.
                raw = inspect.getattr_static(cls, name, None)
                if isinstance(raw, staticmethod):
                    target, rewrap = raw.__func__, staticmethod
                elif isinstance(raw, classmethod):
                    target, rewrap = raw.__func__, classmethod
                else:
                    target, rewrap = method, None

                if hasattr(target, "_judgment_span_name"):
                    if warn_on_double_decoration:
                        judgeval_logger.info(
                            f"{cls.__name__}.{name} already decorated, skipping"
                        )
                    continue

                try:
                    decorated_method = self.observe(target, span_type="tool")
                    if rewrap is not None:
                        decorated_method = rewrap(decorated_method)
                    setattr(cls, name, decorated_method)
                except Exception as e:
                    if warn_on_double_decoration:
                        judgeval_logger.warning(
                            f"Failed to decorate {cls.__name__}.{name}: {e}"
                        )

            return cls

        return decorate_class if cls is None else decorate_class(cls)

    def async_evaluate(self, *args, **kwargs):
        """Forward to the current trace's async_evaluate; errors are logged, not raised.

        Does nothing when monitoring or evaluations are disabled.
        """
        try:
            if not self.enable_monitoring or not self.enable_evaluations:
                return

            current_trace = self.get_current_trace()

            if current_trace:
                current_trace.async_evaluate(*args, **kwargs)
            else:
                judgeval_logger.warning(
                    "No trace found (context var or fallback), skipping evaluation"
                )
        except Exception as e:
            judgeval_logger.warning(f"Issue with async_evaluate: {e}")

    def update_metadata(self, metadata: dict):
        """
        Update metadata for the current trace.

        Args:
            metadata: Metadata as a dictionary
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.update_metadata(metadata)
        else:
            judgeval_logger.warning("No current trace found, cannot set metadata")

    def set_customer_id(self, customer_id: str):
        """
        Set the customer ID for the current trace.

        Args:
            customer_id: The customer ID to set
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.set_customer_id(customer_id)
        else:
            judgeval_logger.warning("No current trace found, cannot set customer ID")

    def set_tags(self, tags: List[Union[str, set, tuple]]):
        """
        Set the tags for the current trace.

        Args:
            tags: List of tags to set
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.set_tags(tags)
        else:
            judgeval_logger.warning("No current trace found, cannot set tags")

    def set_reward_score(self, reward_score: Union[float, Dict[str, float]]):
        """
        Set the reward score for this trace to be used for RL or SFT.

        Args:
            reward_score: The reward score to set
        """
        current_trace = self.get_current_trace()
        if current_trace:
            current_trace.set_reward_score(reward_score)
        else:
            judgeval_logger.warning("No current trace found, cannot set reward score")

    def get_otel_span_processor(self) -> SpanProcessorBase:
        """Get the OpenTelemetry span processor instance."""
        return self.otel_span_processor

    def flush_background_spans(self, timeout_millis: int = 30000):
        """Flush all pending spans in the background service."""
        self.otel_span_processor.force_flush(timeout_millis)

    def shutdown_background_service(self):
        """Shut down the span processor and replace it with a no-op base processor."""
        self.otel_span_processor.shutdown()
        self.otel_span_processor = SpanProcessorBase()

    def _cleanup_on_exit(self):
        """Cleanup handler called on application exit to ensure spans are flushed."""
        try:
            self.flush_background_spans()
        except Exception as e:
            judgeval_logger.warning(f"Error during tracer cleanup: {e}")
        finally:
            try:
                self.shutdown_background_service()
            except Exception as e:
                judgeval_logger.warning(
                    f"Error during background service shutdown: {e}"
                )
|
1623
|
+
|
1624
|
+
|
1625
|
+
def _get_current_trace(
    trace_across_async_contexts: Optional[bool] = None,
):
    """Return the active trace for module-level helpers such as ``wrap``.

    Args:
        trace_across_async_contexts: When True, read the trace stored on the
            ``Tracer`` class; otherwise read the ``current_trace_var`` context
            variable. When omitted, ``Tracer.trace_across_async_contexts`` is
            read at call time. Tracer.__init__ assigns that class attribute,
            so a default captured at import time would miss it.

    Returns:
        The current trace, or None when there is none.
    """
    if trace_across_async_contexts is None:
        trace_across_async_contexts = Tracer.trace_across_async_contexts
    if trace_across_async_contexts:
        return Tracer.current_trace
    try:
        return current_trace_var.get()
    except LookupError:
        # The context variable has not been set in this context.
        return None
|
1632
|
+
|
1633
|
+
|
1634
|
+
def wrap(
    client: Any, trace_across_async_contexts: Optional[bool] = None
) -> Any:
    """
    Wrap an LLM API client so its generation calls are recorded as trace spans.

    The client is patched in place and also returned. Patched methods:
        - OpenAI / AsyncOpenAI: ``chat.completions.create``,
          ``responses.create`` and ``beta.chat.completions.parse``
        - Together / AsyncTogether: ``chat.completions.create``
        - Anthropic / AsyncAnthropic: ``messages.create``
          (``messages.stream`` is looked up but NOT patched)
        - Google GenAI Client / AsyncClient: ``models.generate_content``

    When a trace is active, each patched call runs inside a span named after
    the client (e.g. ``"OPENAI_API_CALL"``) with ``span_type="llm"``. The
    call's keyword arguments are recorded as input and the formatted output
    and usage as output. When no trace is active, the original method is
    called directly.

    Exceptions raised by the API call are captured on the span and re-raised
    unchanged. Errors while formatting or recording the output are only
    logged, because the call itself succeeded.

    Args:
        client: The API client instance to patch.
        trace_across_async_contexts: How the active trace is found. ``True``
            reads ``Tracer.current_trace``; ``False`` reads the
            ``current_trace_var`` context variable. ``None`` (default) uses the
            value of ``Tracer.trace_across_async_contexts`` at the moment each
            wrapped method is invoked.

    Returns:
        The same ``client`` object, with its methods patched.

    Raises:
        ValueError: If the client type is not supported (raised by
            ``_get_client_config``).
    """
    import functools

    (
        span_name,
        original_create,
        original_responses_create,
        _original_stream,  # returned by _get_client_config, not patched here
        original_beta_parse,
    ) = _get_client_config(client)

    def get_active_trace():
        # Resolve the lookup mode per call. A default bound at definition time
        # would freeze Tracer.trace_across_async_contexts at import time.
        across_contexts = (
            Tracer.trace_across_async_contexts
            if trace_across_async_contexts is None
            else trace_across_async_contexts
        )
        return _get_current_trace(across_contexts)

    def process_span(span, response):
        """Record output and usage on ``span`` and return ``response`` unchanged."""
        try:
            output, usage = _format_output_data(client, response)
            span.record_output(output)
            span.record_usage(usage)
        except Exception as e:
            # The API call succeeded. A failure to format or record its output
            # must not surface to the caller as if the call had failed.
            judgeval_logger.warning(
                f"Error recording {span_name} output on trace span: {e}"
            )
        return response

    def wrapped(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            current_trace = get_active_trace()
            if not current_trace:
                return function(*args, **kwargs)

            with current_trace.span(span_name, span_type="llm") as span:
                span.record_input(kwargs)
                try:
                    response = function(*args, **kwargs)
                except Exception:
                    _capture_exception_for_trace(span, sys.exc_info())
                    raise
                return process_span(span, response)

        return wrapper

    def wrapped_async(function):
        @functools.wraps(function)
        async def wrapper(*args, **kwargs):
            current_trace = get_active_trace()
            if not current_trace:
                return await function(*args, **kwargs)

            with current_trace.span(span_name, span_type="llm") as span:
                span.record_input(kwargs)
                try:
                    response = await function(*args, **kwargs)
                except Exception:
                    _capture_exception_for_trace(span, sys.exc_info())
                    raise
                return process_span(span, response)

        return wrapper

    if isinstance(client, OpenAI):
        client.chat.completions.create = wrapped(original_create)
        client.responses.create = wrapped(original_responses_create)
        client.beta.chat.completions.parse = wrapped(original_beta_parse)
    elif isinstance(client, AsyncOpenAI):
        client.chat.completions.create = wrapped_async(original_create)
        client.responses.create = wrapped_async(original_responses_create)
        client.beta.chat.completions.parse = wrapped_async(original_beta_parse)
    elif isinstance(client, Together):
        client.chat.completions.create = wrapped(original_create)
    elif isinstance(client, AsyncTogether):
        client.chat.completions.create = wrapped_async(original_create)
    elif isinstance(client, Anthropic):
        client.messages.create = wrapped(original_create)
    elif isinstance(client, AsyncAnthropic):
        client.messages.create = wrapped_async(original_create)
    elif isinstance(client, genai.Client):
        client.models.generate_content = wrapped(original_create)
    elif isinstance(client, genai.client.AsyncClient):
        client.models.generate_content = wrapped_async(original_create)

    return client
|
1716
|
+
|
1717
|
+
|
1718
|
+
# Helper functions for client-specific operations
|
1719
|
+
|
1720
|
+
|
1721
|
+
def _get_client_config(
    client: ApiClient,
) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
    """Look up the span name and the original methods to patch for an API client.

    Args:
        client: An OpenAI, Together, Anthropic or Google GenAI client
            (sync or async variant).

    Returns:
        A 5-tuple ``(span_name, create_method, responses_method,
        stream_method, beta_parse_method)``:
            - span_name: identifier used as the tracing span name
            - create_method: the client's primary generation method
              (``chat.completions.create``, ``messages.create`` or
              ``models.generate_content``)
            - responses_method: ``responses.create`` for OpenAI, else ``None``
            - stream_method: ``messages.stream`` for Anthropic, else ``None``
            - beta_parse_method: ``beta.chat.completions.parse`` for OpenAI,
              else ``None``

    Raises:
        ValueError: If the client is not one of the supported types.
    """
    if isinstance(client, (OpenAI, AsyncOpenAI)):
        create_method = client.chat.completions.create
        responses_method = client.responses.create
        beta_parse_method = client.beta.chat.completions.parse
        return (
            "OPENAI_API_CALL",
            create_method,
            responses_method,
            None,
            beta_parse_method,
        )

    if isinstance(client, (Together, AsyncTogether)):
        create_method = client.chat.completions.create
        return ("TOGETHER_API_CALL", create_method, None, None, None)

    if isinstance(client, (Anthropic, AsyncAnthropic)):
        create_method = client.messages.create
        stream_method = client.messages.stream
        return ("ANTHROPIC_API_CALL", create_method, None, stream_method, None)

    if isinstance(client, (genai.Client, genai.client.AsyncClient)):
        create_method = client.models.generate_content
        return ("GOOGLE_API_CALL", create_method, None, None, None)

    raise ValueError(f"Unsupported client type: {type(client)}")
|
1761
|
+
|
1762
|
+
|
1763
|
+
def _format_output_data(
    client: ApiClient, response: Any
) -> tuple[Any, Optional[TraceUsage]]:
    """Extract generated content and token usage from an LLM API response.

    Normalizes the response formats of the supported clients into a common
    ``(content, usage)`` pair for tracing. Usage fields that are ``None`` on
    the response object are counted as 0 tokens instead of raising.

    Args:
        client: The client that produced ``response``; its type selects how
            ``response`` is parsed.
        response: The value returned by the client's generation method.

    Returns:
        ``(message_content, usage)`` where
            - message_content: the generated output. This is the first
              choice's message content (or its ``.parsed`` object for a
              ``ParsedChatCompletion``), the joined text of the first
              ``"message"`` output item (OpenAI Responses API), the first
              part's text (Google GenAI) or the first ``"text"`` content block
              (Anthropic). It is ``None`` when nothing was found.
            - usage: a ``TraceUsage`` with token counts, costs and model name.
        For an unsupported client type, a warning is logged and
        ``(None, None)`` is returned.
    """
    prompt_tokens = 0
    completion_tokens = 0
    cache_read_input_tokens = 0
    cache_creation_input_tokens = 0
    model_name = None
    message_content = None

    if isinstance(client, (OpenAI, AsyncOpenAI)):
        # Note: LiteLLM seems to use cache_read_input_tokens to calculate the cost for OpenAI
        if isinstance(response, ChatCompletion):
            model_name = response.model
            raw_usage = response.usage
            if raw_usage is not None:
                prompt_tokens = raw_usage.prompt_tokens or 0
                completion_tokens = raw_usage.completion_tokens or 0
                details = raw_usage.prompt_tokens_details
                if details is not None:
                    cache_read_input_tokens = details.cached_tokens or 0

            if response.choices:
                message = response.choices[0].message
                if isinstance(response, ParsedChatCompletion):
                    message_content = message.parsed
                else:
                    message_content = message.content
        elif isinstance(response, Response):
            model_name = response.model
            raw_usage = response.usage
            if raw_usage is not None:
                prompt_tokens = raw_usage.input_tokens or 0
                completion_tokens = raw_usage.output_tokens or 0
                details = raw_usage.input_tokens_details
                if details is not None:
                    cache_read_input_tokens = details.cached_tokens or 0

            # output[0] is not guaranteed to be a message (presumably it can
            # be e.g. a reasoning or tool-call item), so use the first item of
            # type "message" and join only the segments that carry `.text`.
            first_message = next(
                (
                    item
                    for item in response.output
                    if getattr(item, "type", None) == "message"
                ),
                None,
            )
            if first_message is not None:
                message_content = "".join(
                    segment.text
                    for segment in first_message.content
                    if hasattr(segment, "text")
                )
    elif isinstance(client, (Together, AsyncTogether)):
        # As of 2025-07-14, Together does not do any input cache token tracking
        model_name = "together_ai/" + response.model
        raw_usage = response.usage
        if raw_usage is not None:
            prompt_tokens = raw_usage.prompt_tokens or 0
            completion_tokens = raw_usage.completion_tokens or 0
        if response.choices:
            message_content = response.choices[0].message.content
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        model_name = response.model_version
        metadata = response.usage_metadata
        if metadata is not None:
            prompt_tokens = metadata.prompt_token_count or 0
            completion_tokens = metadata.candidates_token_count or 0
            cache_read_input_tokens = (
                getattr(metadata, "cached_content_token_count", None) or 0
            )
        candidates = response.candidates
        if candidates:
            content = candidates[0].content
            if content is not None and content.parts:
                message_content = content.parts[0].text
    elif isinstance(client, (Anthropic, AsyncAnthropic)):
        model_name = response.model
        raw_usage = response.usage
        prompt_tokens = raw_usage.input_tokens or 0
        completion_tokens = raw_usage.output_tokens or 0
        # The cache counters may be None on the usage object; count them as 0.
        cache_read_input_tokens = (
            getattr(raw_usage, "cache_read_input_tokens", None) or 0
        )
        cache_creation_input_tokens = (
            getattr(raw_usage, "cache_creation_input_tokens", None) or 0
        )
        # content[0] need not be a text block (presumably e.g. thinking or
        # tool_use), so take the first block whose type is "text".
        message_content = next(
            (
                block.text
                for block in response.content
                if getattr(block, "type", None) == "text"
            ),
            None,
        )
    else:
        judgeval_logger.warning(f"Unsupported client type: {type(client)}")
        return None, None

    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
    )
    # A cost of 0 is a legitimate value. Only report the total as unknown when
    # a component could not be computed (cost_per_token returns None then).
    total_cost_usd = (
        prompt_cost + completion_cost
        if prompt_cost is not None and completion_cost is not None
        else None
    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
        model_name=model_name,
    )
    return message_content, usage
|
1850
|
+
|
1851
|
+
|
1852
|
+
def combine_args_kwargs(func, args, kwargs):
    """
    Combine positional arguments and keyword arguments into a single dictionary.

    Each positional argument is keyed by the name of the positional parameter
    it binds to in ``func``'s signature. Positional arguments that do not bind
    to a named positional parameter (for example those absorbed by ``*args``)
    are keyed ``"arg{i}"``, where ``i`` is their index in ``args``. Keyword
    arguments are merged last, so they win on key collisions.

    If ``func``'s signature cannot be inspected, every positional argument is
    keyed ``"arg{i}"``.

    Args:
        func: The function being called
        args: Tuple of positional arguments
        kwargs: Dictionary of keyword arguments

    Returns:
        A dictionary combining both args and kwargs
    """
    import inspect

    try:
        parameters = inspect.signature(func).parameters.values()
    except Exception:
        # Fallback if signature inspection fails
        return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}

    # Only these kinds can receive positional arguments by position. *args and
    # keyword-only parameters must not be matched against positional values.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_names = [p.name for p in parameters if p.kind in positional_kinds]

    args_dict = {
        (positional_names[i] if i < len(positional_names) else f"arg{i}"): arg
        for i, arg in enumerate(args)
    }
    return {**args_dict, **kwargs}
|
1881
|
+
|
1882
|
+
|
1883
|
+
def cost_per_token(*args, **kwargs):
    """Best-effort ``(prompt_cost, completion_cost)`` lookup in USD.

    All arguments are forwarded to ``_original_cost_per_token``; the warning
    text indicates this is LiteLLM's calculator. A warning is logged when both
    returned costs equal 0. Any exception raised while computing or checking
    the costs is logged as a warning, and ``(None, None)`` is returned instead
    of raising.
    """
    try:
        costs = _original_cost_per_token(*args, **kwargs)
        prompt_cost, completion_cost = costs
        both_zero = prompt_cost == 0 and completion_cost == 0
        if both_zero:
            judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
        return prompt_cost, completion_cost
    except Exception as error:
        judgeval_logger.warning(f"Error calculating cost per token: {error}")
        return None, None
|
1897
|
+
|
1898
|
+
|
1899
|
+
# --- Helper function for instance-prefixed qual_name ---
|
1900
|
+
def get_instance_prefixed_name(instance, class_name, class_identifiers):
    """
    Return the identifying attribute value (agent name prefix) of ``instance``.

    ``class_identifiers`` maps class names to config dicts whose
    ``"identifier"`` entry names an attribute of the instance.

    Returns:
        The value of that attribute on ``instance`` when ``class_name`` is
        registered, or ``None`` when ``class_name`` is not in
        ``class_identifiers``.

    Raises:
        Exception: If ``class_name`` is registered but ``instance`` has no
            attribute with the configured identifier name.
    """
    if class_name not in class_identifiers:
        return None

    attr = class_identifiers[class_name]["identifier"]
    if not hasattr(instance, attr):
        raise Exception(
            f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
        )
    return getattr(instance, attr)
|