judgeval 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +139 -12
- judgeval/api/__init__.py +501 -0
- judgeval/api/api_types.py +344 -0
- judgeval/cli.py +2 -4
- judgeval/constants.py +10 -26
- judgeval/data/evaluation_run.py +49 -26
- judgeval/data/example.py +2 -2
- judgeval/data/judgment_types.py +266 -82
- judgeval/data/result.py +4 -5
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +2 -2
- judgeval/data/trace.py +7 -50
- judgeval/data/trace_run.py +7 -4
- judgeval/{dataset.py → dataset/__init__.py} +43 -28
- judgeval/env.py +67 -0
- judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
- judgeval/exceptions.py +27 -0
- judgeval/integrations/langgraph/__init__.py +788 -0
- judgeval/judges/__init__.py +2 -2
- judgeval/judges/litellm_judge.py +75 -15
- judgeval/judges/together_judge.py +86 -18
- judgeval/judges/utils.py +7 -21
- judgeval/{common/logger.py → logger.py} +8 -6
- judgeval/scorers/__init__.py +0 -4
- judgeval/scorers/agent_scorer.py +3 -7
- judgeval/scorers/api_scorer.py +8 -13
- judgeval/scorers/base_scorer.py +52 -32
- judgeval/scorers/example_scorer.py +1 -3
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
- judgeval/scorers/score.py +21 -31
- judgeval/scorers/trace_api_scorer.py +5 -0
- judgeval/scorers/utils.py +1 -103
- judgeval/tracer/__init__.py +1075 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +37 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +43 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +67 -0
- judgeval/tracer/llm/__init__.py +1233 -0
- judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
- judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
- judgeval/tracer/managers.py +188 -0
- judgeval/tracer/processors/__init__.py +181 -0
- judgeval/tracer/utils.py +20 -0
- judgeval/trainer/__init__.py +5 -0
- judgeval/{common/trainer → trainer}/config.py +12 -9
- judgeval/{common/trainer → trainer}/console.py +2 -9
- judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
- judgeval/{common/trainer → trainer}/trainer.py +119 -17
- judgeval/utils/async_utils.py +2 -3
- judgeval/utils/decorators.py +24 -0
- judgeval/utils/file_utils.py +37 -4
- judgeval/utils/guards.py +32 -0
- judgeval/utils/meta.py +14 -0
- judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
- judgeval/utils/testing.py +88 -0
- judgeval/utils/url.py +10 -0
- judgeval/{version_check.py → utils/version_check.py} +3 -3
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
- judgeval-0.9.0.dist-info/RECORD +80 -0
- judgeval/clients.py +0 -35
- judgeval/common/__init__.py +0 -13
- judgeval/common/api/__init__.py +0 -3
- judgeval/common/api/api.py +0 -375
- judgeval/common/api/constants.py +0 -186
- judgeval/common/exceptions.py +0 -27
- judgeval/common/storage/__init__.py +0 -6
- judgeval/common/storage/s3_storage.py +0 -97
- judgeval/common/tracer/__init__.py +0 -31
- judgeval/common/tracer/constants.py +0 -22
- judgeval/common/tracer/core.py +0 -2427
- judgeval/common/tracer/otel_exporter.py +0 -108
- judgeval/common/tracer/otel_span_processor.py +0 -188
- judgeval/common/tracer/span_processor.py +0 -37
- judgeval/common/tracer/span_transformer.py +0 -207
- judgeval/common/tracer/trace_manager.py +0 -101
- judgeval/common/trainer/__init__.py +0 -5
- judgeval/common/utils.py +0 -948
- judgeval/integrations/langgraph.py +0 -844
- judgeval/judges/mixture_of_judges.py +0 -287
- judgeval/judgment_client.py +0 -267
- judgeval/rules.py +0 -521
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
- judgeval/utils/alerts.py +0 -93
- judgeval/utils/requests.py +0 -50
- judgeval-0.7.1.dist-info/RECORD +0 -82
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
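The bulk of this release is a module reorganization: most `judgeval/common/*` modules move to top-level packages (judgeval.tracer, judgeval.trainer, judgeval.logger, judgeval.utils). Below is a minimal before/after import sketch derived only from the rename entries above; the assumption that each moved module still exports the same public names (judgeval_logger, LocalEvaluationQueue) is not confirmed by this diff:

# judgeval 0.7.1 import paths (removed in 0.9.0)
# from judgeval.common.logger import judgeval_logger
# from judgeval.local_eval_queue import LocalEvaluationQueue

# judgeval 0.9.0 equivalents, assuming the exported names are unchanged
from judgeval.logger import judgeval_logger
from judgeval.tracer.local_eval_queue import LocalEvaluationQueue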
judgeval/common/tracer/core.py
DELETED
@@ -1,2427 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Tracing system for judgeval that allows for function tracing using decorators.
|
3
|
-
"""
|
4
|
-
|
5
|
-
from __future__ import annotations
|
6
|
-
|
7
|
-
import asyncio
|
8
|
-
import atexit
|
9
|
-
import functools
|
10
|
-
import inspect
|
11
|
-
import os
|
12
|
-
import threading
|
13
|
-
import time
|
14
|
-
import traceback
|
15
|
-
import uuid
|
16
|
-
import contextvars
|
17
|
-
import sys
|
18
|
-
from contextlib import (
|
19
|
-
contextmanager,
|
20
|
-
)
|
21
|
-
from datetime import datetime, timezone
|
22
|
-
from typing import (
|
23
|
-
Any,
|
24
|
-
Callable,
|
25
|
-
Dict,
|
26
|
-
Generator,
|
27
|
-
List,
|
28
|
-
Optional,
|
29
|
-
ParamSpec,
|
30
|
-
Tuple,
|
31
|
-
TypeVar,
|
32
|
-
Union,
|
33
|
-
TypeAlias,
|
34
|
-
overload,
|
35
|
-
)
|
36
|
-
import types
|
37
|
-
import random
|
38
|
-
|
39
|
-
|
40
|
-
from judgeval.common.tracer.constants import _TRACE_FILEPATH_BLOCKLIST
|
41
|
-
|
42
|
-
from judgeval.common.tracer.otel_span_processor import JudgmentSpanProcessor
|
43
|
-
from judgeval.common.tracer.span_processor import SpanProcessorBase
|
44
|
-
from judgeval.common.tracer.trace_manager import TraceManagerClient
|
45
|
-
|
46
|
-
from judgeval.data import Example, Trace, TraceSpan, TraceUsage
|
47
|
-
from judgeval.scorers import APIScorerConfig, BaseScorer
|
48
|
-
from judgeval.data.evaluation_run import EvaluationRun
|
49
|
-
from judgeval.local_eval_queue import LocalEvaluationQueue
|
50
|
-
from judgeval.common.api import JudgmentApiClient
|
51
|
-
from judgeval.common.utils import OptExcInfo, validate_api_key
|
52
|
-
from judgeval.common.logger import judgeval_logger
|
53
|
-
|
54
|
-
from litellm import cost_per_token as _original_cost_per_token # type: ignore
|
55
|
-
from judgeval.common.tracer.providers import (
|
56
|
-
HAS_OPENAI,
|
57
|
-
HAS_TOGETHER,
|
58
|
-
HAS_ANTHROPIC,
|
59
|
-
HAS_GOOGLE_GENAI,
|
60
|
-
HAS_GROQ,
|
61
|
-
ApiClient,
|
62
|
-
)
|
63
|
-
from judgeval.constants import DEFAULT_GPT_MODEL
|
64
|
-
|
65
|
-
|
66
|
-
current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
|
67
|
-
"current_trace", default=None
|
68
|
-
)
|
69
|
-
current_span_var = contextvars.ContextVar[Optional[str]]("current_span", default=None)
|
70
|
-
|
71
|
-
|
72
|
-
SpanType: TypeAlias = str
|
73
|
-
|
74
|
-
|
75
|
-
class TraceClient:
|
76
|
-
"""Client for managing a single trace context"""
|
77
|
-
|
78
|
-
def __init__(
|
79
|
-
self,
|
80
|
-
tracer: Tracer,
|
81
|
-
trace_id: Optional[str] = None,
|
82
|
-
name: str = "default",
|
83
|
-
project_name: Union[str, None] = None,
|
84
|
-
enable_monitoring: bool = True,
|
85
|
-
enable_evaluations: bool = True,
|
86
|
-
parent_trace_id: Optional[str] = None,
|
87
|
-
parent_name: Optional[str] = None,
|
88
|
-
):
|
89
|
-
self.name = name
|
90
|
-
self.trace_id = trace_id or str(uuid.uuid4())
|
91
|
-
self.project_name = project_name or "default_project"
|
92
|
-
self.tracer = tracer
|
93
|
-
self.enable_monitoring = enable_monitoring
|
94
|
-
self.enable_evaluations = enable_evaluations
|
95
|
-
self.parent_trace_id = parent_trace_id
|
96
|
-
self.parent_name = parent_name
|
97
|
-
self.customer_id: Optional[str] = None
|
98
|
-
self.tags: List[Union[str, set, tuple]] = []
|
99
|
-
self.metadata: Dict[str, Any] = {}
|
100
|
-
self.has_notification: Optional[bool] = False
|
101
|
-
self.update_id: int = 1
|
102
|
-
self.trace_spans: List[TraceSpan] = []
|
103
|
-
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
104
|
-
self.evaluation_runs: List[EvaluationRun] = []
|
105
|
-
self.start_time: Optional[float] = None
|
106
|
-
self.trace_manager_client = TraceManagerClient(
|
107
|
-
tracer.api_key, tracer.organization_id, tracer
|
108
|
-
)
|
109
|
-
self._span_depths: Dict[str, int] = {}
|
110
|
-
|
111
|
-
self.otel_span_processor = tracer.otel_span_processor
|
112
|
-
|
113
|
-
def get_current_span(self):
|
114
|
-
"""Get the current span from the context var"""
|
115
|
-
return self.tracer.get_current_span()
|
116
|
-
|
117
|
-
def set_current_span(self, span: Any):
|
118
|
-
"""Set the current span from the context var"""
|
119
|
-
return self.tracer.set_current_span(span)
|
120
|
-
|
121
|
-
def reset_current_span(self, token: Any):
|
122
|
-
"""Reset the current span from the context var"""
|
123
|
-
self.tracer.reset_current_span(token)
|
124
|
-
|
125
|
-
@contextmanager
|
126
|
-
def span(self, name: str, span_type: SpanType = "span"):
|
127
|
-
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
128
|
-
is_first_span = len(self.trace_spans) == 0
|
129
|
-
if is_first_span:
|
130
|
-
try:
|
131
|
-
self.save(final_save=False)
|
132
|
-
except Exception as e:
|
133
|
-
judgeval_logger.warning(
|
134
|
-
f"Failed to save initial trace for live tracking: {e}"
|
135
|
-
)
|
136
|
-
start_time = time.time()
|
137
|
-
|
138
|
-
span_id = str(uuid.uuid4())
|
139
|
-
|
140
|
-
parent_span_id = self.get_current_span()
|
141
|
-
token = self.set_current_span(span_id)
|
142
|
-
|
143
|
-
current_depth = 0
|
144
|
-
if parent_span_id and parent_span_id in self._span_depths:
|
145
|
-
current_depth = self._span_depths[parent_span_id] + 1
|
146
|
-
|
147
|
-
self._span_depths[span_id] = current_depth
|
148
|
-
|
149
|
-
span = TraceSpan(
|
150
|
-
span_id=span_id,
|
151
|
-
trace_id=self.trace_id,
|
152
|
-
depth=current_depth,
|
153
|
-
message=name,
|
154
|
-
created_at=start_time,
|
155
|
-
span_type=span_type,
|
156
|
-
parent_span_id=parent_span_id,
|
157
|
-
function=name,
|
158
|
-
)
|
159
|
-
self.add_span(span)
|
160
|
-
|
161
|
-
self.otel_span_processor.queue_span_update(span, span_state="input")
|
162
|
-
|
163
|
-
try:
|
164
|
-
yield self
|
165
|
-
finally:
|
166
|
-
duration = time.time() - start_time
|
167
|
-
span.duration = duration
|
168
|
-
|
169
|
-
self.otel_span_processor.queue_span_update(span, span_state="completed")
|
170
|
-
|
171
|
-
if span_id in self._span_depths:
|
172
|
-
del self._span_depths[span_id]
|
173
|
-
self.reset_current_span(token)
|
174
|
-
|
175
|
-
def async_evaluate(
|
176
|
-
self,
|
177
|
-
scorer: Union[APIScorerConfig, BaseScorer],
|
178
|
-
example: Example,
|
179
|
-
model: str = DEFAULT_GPT_MODEL,
|
180
|
-
):
|
181
|
-
start_time = time.time()
|
182
|
-
span_id = self.get_current_span()
|
183
|
-
eval_run_name = (
|
184
|
-
f"{self.name.capitalize()}-{span_id}-{scorer.score_type.capitalize()}"
|
185
|
-
)
|
186
|
-
hosted_scoring = isinstance(scorer, APIScorerConfig) or (
|
187
|
-
isinstance(scorer, BaseScorer) and scorer.server_hosted
|
188
|
-
)
|
189
|
-
if hosted_scoring:
|
190
|
-
eval_run = EvaluationRun(
|
191
|
-
organization_id=self.tracer.organization_id,
|
192
|
-
project_name=self.project_name,
|
193
|
-
eval_name=eval_run_name,
|
194
|
-
examples=[example],
|
195
|
-
scorers=[scorer],
|
196
|
-
model=model,
|
197
|
-
trace_span_id=span_id,
|
198
|
-
)
|
199
|
-
|
200
|
-
self.add_eval_run(eval_run, start_time)
|
201
|
-
|
202
|
-
if span_id:
|
203
|
-
current_span = self.span_id_to_span.get(span_id)
|
204
|
-
if current_span:
|
205
|
-
self.otel_span_processor.queue_evaluation_run(
|
206
|
-
eval_run, span_id=span_id, span_data=current_span
|
207
|
-
)
|
208
|
-
else:
|
209
|
-
# Handle custom scorers using local evaluation queue
|
210
|
-
eval_run = EvaluationRun(
|
211
|
-
organization_id=self.tracer.organization_id,
|
212
|
-
project_name=self.project_name,
|
213
|
-
eval_name=eval_run_name,
|
214
|
-
examples=[example],
|
215
|
-
scorers=[scorer],
|
216
|
-
model=model,
|
217
|
-
trace_span_id=span_id,
|
218
|
-
)
|
219
|
-
|
220
|
-
self.add_eval_run(eval_run, start_time)
|
221
|
-
|
222
|
-
# Enqueue the evaluation run to the local evaluation queue
|
223
|
-
self.tracer.local_eval_queue.enqueue(eval_run)
|
224
|
-
|
225
|
-
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
226
|
-
current_span_id = eval_run.trace_span_id
|
227
|
-
|
228
|
-
if current_span_id:
|
229
|
-
span = self.span_id_to_span[current_span_id]
|
230
|
-
span.has_evaluation = True
|
231
|
-
self.evaluation_runs.append(eval_run)
|
232
|
-
|
233
|
-
def record_input(self, inputs: dict):
|
234
|
-
current_span_id = self.get_current_span()
|
235
|
-
if current_span_id:
|
236
|
-
span = self.span_id_to_span[current_span_id]
|
237
|
-
if "self" in inputs:
|
238
|
-
del inputs["self"]
|
239
|
-
span.inputs = inputs
|
240
|
-
|
241
|
-
try:
|
242
|
-
self.otel_span_processor.queue_span_update(span, span_state="input")
|
243
|
-
except Exception as e:
|
244
|
-
judgeval_logger.warning(f"Failed to queue span with input data: {e}")
|
245
|
-
|
246
|
-
def record_agent_name(self, agent_name: str):
|
247
|
-
current_span_id = self.get_current_span()
|
248
|
-
if current_span_id:
|
249
|
-
span = self.span_id_to_span[current_span_id]
|
250
|
-
span.agent_name = agent_name
|
251
|
-
|
252
|
-
self.otel_span_processor.queue_span_update(span, span_state="agent_name")
|
253
|
-
|
254
|
-
def record_class_name(self, class_name: str):
|
255
|
-
current_span_id = self.get_current_span()
|
256
|
-
if current_span_id:
|
257
|
-
span = self.span_id_to_span[current_span_id]
|
258
|
-
span.class_name = class_name
|
259
|
-
|
260
|
-
self.otel_span_processor.queue_span_update(span, span_state="class_name")
|
261
|
-
|
262
|
-
def record_state_before(self, state: dict):
|
263
|
-
"""Records the agent's state before a tool execution on the current span.
|
264
|
-
|
265
|
-
Args:
|
266
|
-
state: A dictionary representing the agent's state.
|
267
|
-
"""
|
268
|
-
current_span_id = self.get_current_span()
|
269
|
-
if current_span_id:
|
270
|
-
span = self.span_id_to_span[current_span_id]
|
271
|
-
span.state_before = state
|
272
|
-
|
273
|
-
self.otel_span_processor.queue_span_update(span, span_state="state_before")
|
274
|
-
|
275
|
-
def record_state_after(self, state: dict):
|
276
|
-
"""Records the agent's state after a tool execution on the current span.
|
277
|
-
|
278
|
-
Args:
|
279
|
-
state: A dictionary representing the agent's state.
|
280
|
-
"""
|
281
|
-
current_span_id = self.get_current_span()
|
282
|
-
if current_span_id:
|
283
|
-
span = self.span_id_to_span[current_span_id]
|
284
|
-
span.state_after = state
|
285
|
-
|
286
|
-
self.otel_span_processor.queue_span_update(span, span_state="state_after")
|
287
|
-
|
288
|
-
def record_output(self, output: Any):
|
289
|
-
current_span_id = self.get_current_span()
|
290
|
-
if current_span_id:
|
291
|
-
span = self.span_id_to_span[current_span_id]
|
292
|
-
span.output = output
|
293
|
-
|
294
|
-
self.otel_span_processor.queue_span_update(span, span_state="output")
|
295
|
-
|
296
|
-
return span
|
297
|
-
return None
|
298
|
-
|
299
|
-
def record_usage(self, usage: TraceUsage):
|
300
|
-
current_span_id = self.get_current_span()
|
301
|
-
if current_span_id:
|
302
|
-
span = self.span_id_to_span[current_span_id]
|
303
|
-
span.usage = usage
|
304
|
-
|
305
|
-
self.otel_span_processor.queue_span_update(span, span_state="usage")
|
306
|
-
|
307
|
-
return span
|
308
|
-
return None
|
309
|
-
|
310
|
-
def record_error(self, error: Dict[str, Any]):
|
311
|
-
current_span_id = self.get_current_span()
|
312
|
-
if current_span_id:
|
313
|
-
span = self.span_id_to_span[current_span_id]
|
314
|
-
span.error = error
|
315
|
-
|
316
|
-
self.otel_span_processor.queue_span_update(span, span_state="error")
|
317
|
-
|
318
|
-
return span
|
319
|
-
return None
|
320
|
-
|
321
|
-
def add_span(self, span: TraceSpan):
|
322
|
-
"""Add a trace span to this trace context"""
|
323
|
-
self.trace_spans.append(span)
|
324
|
-
self.span_id_to_span[span.span_id] = span
|
325
|
-
return self
|
326
|
-
|
327
|
-
def print(self):
|
328
|
-
"""Print the complete trace with proper visual structure"""
|
329
|
-
for span in self.trace_spans:
|
330
|
-
span.print_span()
|
331
|
-
|
332
|
-
def get_duration(self) -> float:
|
333
|
-
"""
|
334
|
-
Get the total duration of this trace
|
335
|
-
"""
|
336
|
-
if self.start_time is None:
|
337
|
-
return 0.0
|
338
|
-
return time.time() - self.start_time
|
339
|
-
|
340
|
-
def save(self, final_save: bool = False) -> Tuple[str, dict]:
|
341
|
-
"""
|
342
|
-
Save the current trace to the database with rate limiting checks.
|
343
|
-
First checks usage limits, then upserts the trace if allowed.
|
344
|
-
|
345
|
-
Args:
|
346
|
-
final_save: Whether this is the final save (updates usage counters)
|
347
|
-
|
348
|
-
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
349
|
-
"""
|
350
|
-
if final_save:
|
351
|
-
try:
|
352
|
-
self.otel_span_processor.flush_pending_spans()
|
353
|
-
except Exception as e:
|
354
|
-
judgeval_logger.warning(
|
355
|
-
f"Error flushing spans for trace {self.trace_id}: {e}"
|
356
|
-
)
|
357
|
-
|
358
|
-
total_duration = self.get_duration()
|
359
|
-
|
360
|
-
trace_data = {
|
361
|
-
"trace_id": self.trace_id,
|
362
|
-
"name": self.name,
|
363
|
-
"project_name": self.project_name,
|
364
|
-
"created_at": datetime.fromtimestamp(
|
365
|
-
self.start_time or time.time(), timezone.utc
|
366
|
-
).isoformat(),
|
367
|
-
"duration": total_duration,
|
368
|
-
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
369
|
-
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
370
|
-
"offline_mode": self.tracer.offline_mode,
|
371
|
-
"parent_trace_id": self.parent_trace_id,
|
372
|
-
"parent_name": self.parent_name,
|
373
|
-
"customer_id": self.customer_id,
|
374
|
-
"tags": self.tags,
|
375
|
-
"metadata": self.metadata,
|
376
|
-
"update_id": self.update_id,
|
377
|
-
}
|
378
|
-
|
379
|
-
server_response = self.trace_manager_client.upsert_trace(
|
380
|
-
trace_data,
|
381
|
-
offline_mode=self.tracer.offline_mode,
|
382
|
-
show_link=not final_save,
|
383
|
-
final_save=final_save,
|
384
|
-
)
|
385
|
-
|
386
|
-
if self.start_time is None:
|
387
|
-
self.start_time = time.time()
|
388
|
-
|
389
|
-
self.update_id += 1
|
390
|
-
|
391
|
-
return self.trace_id, server_response
|
392
|
-
|
393
|
-
def delete(self):
|
394
|
-
return self.trace_manager_client.delete_trace(self.trace_id)
|
395
|
-
|
396
|
-
def update_metadata(self, metadata: dict):
|
397
|
-
"""
|
398
|
-
Set metadata for this trace.
|
399
|
-
|
400
|
-
Args:
|
401
|
-
metadata: Metadata as a dictionary
|
402
|
-
|
403
|
-
Supported keys:
|
404
|
-
- customer_id: ID of the customer using this trace
|
405
|
-
- tags: List of tags for this trace
|
406
|
-
- has_notification: Whether this trace has a notification
|
407
|
-
- name: Name of the trace
|
408
|
-
"""
|
409
|
-
for k, v in metadata.items():
|
410
|
-
if k == "customer_id":
|
411
|
-
if v is not None:
|
412
|
-
self.customer_id = str(v)
|
413
|
-
else:
|
414
|
-
self.customer_id = None
|
415
|
-
elif k == "tags":
|
416
|
-
if isinstance(v, list):
|
417
|
-
for item in v:
|
418
|
-
if not isinstance(item, (str, set, tuple)):
|
419
|
-
raise ValueError(
|
420
|
-
f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
|
421
|
-
)
|
422
|
-
self.tags = v
|
423
|
-
else:
|
424
|
-
raise ValueError(
|
425
|
-
f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
|
426
|
-
)
|
427
|
-
elif k == "has_notification":
|
428
|
-
if not isinstance(v, bool):
|
429
|
-
raise ValueError(
|
430
|
-
f"has_notification must be a boolean, got {type(v)}"
|
431
|
-
)
|
432
|
-
self.has_notification = v
|
433
|
-
elif k == "name":
|
434
|
-
self.name = v
|
435
|
-
else:
|
436
|
-
self.metadata[k] = v
|
437
|
-
|
438
|
-
def set_customer_id(self, customer_id: str):
|
439
|
-
"""
|
440
|
-
Set the customer ID for this trace.
|
441
|
-
|
442
|
-
Args:
|
443
|
-
customer_id: The customer ID to set
|
444
|
-
"""
|
445
|
-
self.update_metadata({"customer_id": customer_id})
|
446
|
-
|
447
|
-
def set_tags(self, tags: List[Union[str, set, tuple]]):
|
448
|
-
"""
|
449
|
-
Set the tags for this trace.
|
450
|
-
|
451
|
-
Args:
|
452
|
-
tags: List of tags to set
|
453
|
-
"""
|
454
|
-
self.update_metadata({"tags": tags})
|
455
|
-
|
456
|
-
def set_reward_score(self, reward_score: Union[float, Dict[str, float]]):
|
457
|
-
"""
|
458
|
-
Set the reward score for this trace to be used for RL or SFT.
|
459
|
-
|
460
|
-
Args:
|
461
|
-
reward_score: The reward score to set
|
462
|
-
"""
|
463
|
-
self.update_metadata({"reward_score": reward_score})
|
464
|
-
|
465
|
-
|
466
|
-
def _capture_exception_for_trace(
|
467
|
-
current_trace: Optional[TraceClient], exc_info: OptExcInfo
|
468
|
-
):
|
469
|
-
if not current_trace:
|
470
|
-
return
|
471
|
-
|
472
|
-
exc_type, exc_value, exc_traceback_obj = exc_info
|
473
|
-
formatted_exception = {
|
474
|
-
"type": exc_type.__name__ if exc_type else "UnknownExceptionType",
|
475
|
-
"message": str(exc_value) if exc_value else "No exception message",
|
476
|
-
"traceback": (
|
477
|
-
traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
|
478
|
-
),
|
479
|
-
}
|
480
|
-
|
481
|
-
# This is where we specially handle exceptions that we might want to collect additional data for.
|
482
|
-
# When we do this, always try checking the module from sys.modules instead of importing. This will
|
483
|
-
# Let us support a wider range of exceptions without needing to import them for all clients.
|
484
|
-
|
485
|
-
# Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
|
486
|
-
# The alternative is to hand select libraries we want from sys.modules and check for them:
|
487
|
-
# As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
|
488
|
-
|
489
|
-
# General HTTP Like errors
|
490
|
-
try:
|
491
|
-
url = getattr(getattr(exc_value, "request", None), "url", None)
|
492
|
-
status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
|
493
|
-
if status_code:
|
494
|
-
formatted_exception["http"] = {
|
495
|
-
"url": url if url else "Unknown URL",
|
496
|
-
"status_code": status_code if status_code else None,
|
497
|
-
}
|
498
|
-
except Exception:
|
499
|
-
pass
|
500
|
-
|
501
|
-
current_trace.record_error(formatted_exception)
|
502
|
-
|
503
|
-
|
504
|
-
class _DeepTracer:
|
505
|
-
_instance: Optional["_DeepTracer"] = None
|
506
|
-
_lock: threading.Lock = threading.Lock()
|
507
|
-
_refcount: int = 0
|
508
|
-
_span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
|
509
|
-
"_deep_profiler_span_stack", default=[]
|
510
|
-
)
|
511
|
-
_skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
|
512
|
-
"_deep_profiler_skip_stack", default=[]
|
513
|
-
)
|
514
|
-
_original_sys_trace: Optional[Callable] = None
|
515
|
-
_original_threading_trace: Optional[Callable] = None
|
516
|
-
|
517
|
-
def __init__(self, tracer: "Tracer"):
|
518
|
-
self._tracer = tracer
|
519
|
-
|
520
|
-
def _get_qual_name(self, frame) -> str:
|
521
|
-
func_name = frame.f_code.co_name
|
522
|
-
module_name = frame.f_globals.get("__name__", "unknown_module")
|
523
|
-
|
524
|
-
try:
|
525
|
-
func = frame.f_globals.get(func_name)
|
526
|
-
if func is None:
|
527
|
-
return f"{module_name}.{func_name}"
|
528
|
-
if hasattr(func, "__qualname__"):
|
529
|
-
return f"{module_name}.{func.__qualname__}"
|
530
|
-
return f"{module_name}.{func_name}"
|
531
|
-
except Exception:
|
532
|
-
return f"{module_name}.{func_name}"
|
533
|
-
|
534
|
-
def __new__(cls, tracer: "Tracer"):
|
535
|
-
with cls._lock:
|
536
|
-
if cls._instance is None:
|
537
|
-
cls._instance = super().__new__(cls)
|
538
|
-
return cls._instance
|
539
|
-
|
540
|
-
def _should_trace(self, frame):
|
541
|
-
# Skip stack is maintained by the tracer as an optimization to skip earlier
|
542
|
-
# frames in the call stack that we've already determined should be skipped
|
543
|
-
skip_stack = self._skip_stack.get()
|
544
|
-
if len(skip_stack) > 0:
|
545
|
-
return False
|
546
|
-
|
547
|
-
func_name = frame.f_code.co_name
|
548
|
-
module_name = frame.f_globals.get("__name__", None)
|
549
|
-
func = frame.f_globals.get(func_name)
|
550
|
-
if func and (
|
551
|
-
hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
|
552
|
-
):
|
553
|
-
return False
|
554
|
-
|
555
|
-
if (
|
556
|
-
not module_name
|
557
|
-
or func_name.startswith("<") # ex: <listcomp>
|
558
|
-
or func_name.startswith("__")
|
559
|
-
and func_name != "__call__" # dunders
|
560
|
-
or not self._is_user_code(frame.f_code.co_filename)
|
561
|
-
):
|
562
|
-
return False
|
563
|
-
|
564
|
-
return True
|
565
|
-
|
566
|
-
@functools.cache
|
567
|
-
def _is_user_code(self, filename: str):
|
568
|
-
return (
|
569
|
-
bool(filename)
|
570
|
-
and not filename.startswith("<")
|
571
|
-
and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
|
572
|
-
)
|
573
|
-
|
574
|
-
def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
|
575
|
-
"""Cooperative trace function for sys.settrace that chains with existing tracers."""
|
576
|
-
# First, call the original sys trace function if it exists
|
577
|
-
original_result = None
|
578
|
-
if self._original_sys_trace:
|
579
|
-
try:
|
580
|
-
original_result = self._original_sys_trace(frame, event, arg)
|
581
|
-
except Exception:
|
582
|
-
pass
|
583
|
-
|
584
|
-
our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
|
585
|
-
|
586
|
-
if original_result is None and self._original_sys_trace:
|
587
|
-
return None
|
588
|
-
|
589
|
-
return our_result or original_result
|
590
|
-
|
591
|
-
def _cooperative_threading_trace(
|
592
|
-
self, frame: types.FrameType, event: str, arg: Any
|
593
|
-
):
|
594
|
-
"""Cooperative trace function for threading.settrace that chains with existing tracers."""
|
595
|
-
original_result = None
|
596
|
-
if self._original_threading_trace:
|
597
|
-
try:
|
598
|
-
original_result = self._original_threading_trace(frame, event, arg)
|
599
|
-
except Exception:
|
600
|
-
pass
|
601
|
-
|
602
|
-
our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
|
603
|
-
|
604
|
-
if original_result is None and self._original_threading_trace:
|
605
|
-
return None
|
606
|
-
|
607
|
-
return our_result or original_result
|
608
|
-
|
609
|
-
def _trace(
|
610
|
-
self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
|
611
|
-
):
|
612
|
-
frame.f_trace_lines = False
|
613
|
-
frame.f_trace_opcodes = False
|
614
|
-
|
615
|
-
if not self._should_trace(frame):
|
616
|
-
return
|
617
|
-
|
618
|
-
if event not in ("call", "return", "exception"):
|
619
|
-
return
|
620
|
-
|
621
|
-
current_trace = self._tracer.get_current_trace()
|
622
|
-
if not current_trace:
|
623
|
-
return
|
624
|
-
|
625
|
-
parent_span_id = self._tracer.get_current_span()
|
626
|
-
if not parent_span_id:
|
627
|
-
return
|
628
|
-
|
629
|
-
qual_name = self._get_qual_name(frame)
|
630
|
-
instance_name = None
|
631
|
-
class_name = None
|
632
|
-
if "self" in frame.f_locals:
|
633
|
-
instance = frame.f_locals["self"]
|
634
|
-
class_name = instance.__class__.__name__
|
635
|
-
class_identifiers = getattr(self._tracer, "class_identifiers", {})
|
636
|
-
instance_name = get_instance_prefixed_name(
|
637
|
-
instance, class_name, class_identifiers
|
638
|
-
)
|
639
|
-
skip_stack = self._skip_stack.get()
|
640
|
-
|
641
|
-
if event == "call":
|
642
|
-
# If we have entries in the skip stack and the current qual_name matches the top entry,
|
643
|
-
# push it again to track nesting depth and skip
|
644
|
-
# As an optimization, we only care about duplicate qual_names.
|
645
|
-
if skip_stack:
|
646
|
-
if qual_name == skip_stack[-1]:
|
647
|
-
skip_stack.append(qual_name)
|
648
|
-
self._skip_stack.set(skip_stack)
|
649
|
-
return
|
650
|
-
|
651
|
-
should_trace = self._should_trace(frame)
|
652
|
-
|
653
|
-
if not should_trace:
|
654
|
-
if not skip_stack:
|
655
|
-
self._skip_stack.set([qual_name])
|
656
|
-
return
|
657
|
-
elif event == "return":
|
658
|
-
# If we have entries in skip stack and current qual_name matches the top entry,
|
659
|
-
# pop it to track exiting from the skipped section
|
660
|
-
if skip_stack and qual_name == skip_stack[-1]:
|
661
|
-
skip_stack.pop()
|
662
|
-
self._skip_stack.set(skip_stack)
|
663
|
-
return
|
664
|
-
|
665
|
-
if skip_stack:
|
666
|
-
return
|
667
|
-
|
668
|
-
span_stack = self._span_stack.get()
|
669
|
-
if event == "call":
|
670
|
-
if not self._should_trace(frame):
|
671
|
-
return
|
672
|
-
|
673
|
-
span_id = str(uuid.uuid4())
|
674
|
-
|
675
|
-
parent_depth = current_trace._span_depths.get(parent_span_id, 0)
|
676
|
-
depth = parent_depth + 1
|
677
|
-
|
678
|
-
current_trace._span_depths[span_id] = depth
|
679
|
-
|
680
|
-
start_time = time.time()
|
681
|
-
|
682
|
-
span_stack.append(
|
683
|
-
{
|
684
|
-
"span_id": span_id,
|
685
|
-
"parent_span_id": parent_span_id,
|
686
|
-
"function": qual_name,
|
687
|
-
"start_time": start_time,
|
688
|
-
}
|
689
|
-
)
|
690
|
-
self._span_stack.set(span_stack)
|
691
|
-
|
692
|
-
token = self._tracer.set_current_span(span_id)
|
693
|
-
frame.f_locals["_judgment_span_token"] = token
|
694
|
-
|
695
|
-
span = TraceSpan(
|
696
|
-
span_id=span_id,
|
697
|
-
trace_id=current_trace.trace_id,
|
698
|
-
depth=depth,
|
699
|
-
message=qual_name,
|
700
|
-
created_at=start_time,
|
701
|
-
span_type="span",
|
702
|
-
parent_span_id=parent_span_id,
|
703
|
-
function=qual_name,
|
704
|
-
agent_name=instance_name,
|
705
|
-
class_name=class_name,
|
706
|
-
)
|
707
|
-
current_trace.add_span(span)
|
708
|
-
|
709
|
-
inputs = {}
|
710
|
-
try:
|
711
|
-
args_info = inspect.getargvalues(frame)
|
712
|
-
for arg in args_info.args:
|
713
|
-
try:
|
714
|
-
inputs[arg] = args_info.locals.get(arg)
|
715
|
-
except Exception:
|
716
|
-
inputs[arg] = "<<Unserializable>>"
|
717
|
-
current_trace.record_input(inputs)
|
718
|
-
except Exception as e:
|
719
|
-
current_trace.record_input({"error": str(e)})
|
720
|
-
|
721
|
-
elif event == "return":
|
722
|
-
if not span_stack:
|
723
|
-
return
|
724
|
-
|
725
|
-
current_id = self._tracer.get_current_span()
|
726
|
-
|
727
|
-
span_data = None
|
728
|
-
for i, entry in enumerate(reversed(span_stack)):
|
729
|
-
if entry["span_id"] == current_id:
|
730
|
-
span_data = span_stack.pop(-(i + 1))
|
731
|
-
self._span_stack.set(span_stack)
|
732
|
-
break
|
733
|
-
|
734
|
-
if not span_data:
|
735
|
-
return
|
736
|
-
|
737
|
-
start_time = span_data["start_time"]
|
738
|
-
duration = time.time() - start_time
|
739
|
-
|
740
|
-
current_trace.span_id_to_span[span_data["span_id"]].duration = duration
|
741
|
-
|
742
|
-
if arg is not None:
|
743
|
-
# exception handling will take priority.
|
744
|
-
current_trace.record_output(arg)
|
745
|
-
|
746
|
-
if span_data["span_id"] in current_trace._span_depths:
|
747
|
-
del current_trace._span_depths[span_data["span_id"]]
|
748
|
-
|
749
|
-
if span_stack:
|
750
|
-
self._tracer.set_current_span(span_stack[-1]["span_id"])
|
751
|
-
else:
|
752
|
-
self._tracer.set_current_span(span_data["parent_span_id"])
|
753
|
-
|
754
|
-
if "_judgment_span_token" in frame.f_locals:
|
755
|
-
self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
|
756
|
-
|
757
|
-
elif event == "exception":
|
758
|
-
exc_type = arg[0]
|
759
|
-
if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
|
760
|
-
return
|
761
|
-
_capture_exception_for_trace(current_trace, arg)
|
762
|
-
|
763
|
-
return continuation_func
|
764
|
-
|
765
|
-
def __enter__(self):
|
766
|
-
with self._lock:
|
767
|
-
self._refcount += 1
|
768
|
-
if self._refcount == 1:
|
769
|
-
# Store the existing trace functions before setting ours
|
770
|
-
self._original_sys_trace = sys.gettrace()
|
771
|
-
self._original_threading_trace = threading.gettrace()
|
772
|
-
|
773
|
-
self._skip_stack.set([])
|
774
|
-
self._span_stack.set([])
|
775
|
-
|
776
|
-
sys.settrace(self._cooperative_sys_trace)
|
777
|
-
threading.settrace(self._cooperative_threading_trace)
|
778
|
-
return self
|
779
|
-
|
780
|
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
781
|
-
with self._lock:
|
782
|
-
self._refcount -= 1
|
783
|
-
if self._refcount == 0:
|
784
|
-
# Restore the original trace functions instead of setting to None
|
785
|
-
sys.settrace(self._original_sys_trace)
|
786
|
-
threading.settrace(self._original_threading_trace)
|
787
|
-
|
788
|
-
# Clean up the references
|
789
|
-
self._original_sys_trace = None
|
790
|
-
self._original_threading_trace = None
|
791
|
-
|
792
|
-
|
793
|
-
T = TypeVar("T", bound=Callable[..., Any])
|
794
|
-
P = ParamSpec("P")
|
795
|
-
|
796
|
-
|
797
|
-
class Tracer:
|
798
|
-
# Tracer.current_trace class variable is currently used in wrap()
|
799
|
-
# TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
|
800
|
-
# Should be fine to do so as long as we keep Tracer as a singleton
|
801
|
-
current_trace: Optional[TraceClient] = None
|
802
|
-
# current_span_id: Optional[str] = None
|
803
|
-
|
804
|
-
trace_across_async_contexts: bool = (
|
805
|
-
False # BY default, we don't trace across async contexts
|
806
|
-
)
|
807
|
-
|
808
|
-
def __init__(
|
809
|
-
self,
|
810
|
-
api_key: Union[str, None] = os.getenv("JUDGMENT_API_KEY"),
|
811
|
-
organization_id: Union[str, None] = os.getenv("JUDGMENT_ORG_ID"),
|
812
|
-
project_name: Union[str, None] = None,
|
813
|
-
deep_tracing: bool = False, # Deep tracing is disabled by default
|
814
|
-
enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
|
815
|
-
== "true",
|
816
|
-
enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
|
817
|
-
== "true",
|
818
|
-
show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
|
819
|
-
== "true",
|
820
|
-
# S3 configuration
|
821
|
-
use_s3: bool = False,
|
822
|
-
s3_bucket_name: Optional[str] = None,
|
823
|
-
s3_aws_access_key_id: Optional[str] = None,
|
824
|
-
s3_aws_secret_access_key: Optional[str] = None,
|
825
|
-
s3_region_name: Optional[str] = None,
|
826
|
-
trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
|
827
|
-
span_batch_size: int = 50,
|
828
|
-
span_flush_interval: float = 1.0,
|
829
|
-
span_max_queue_size: int = 2048,
|
830
|
-
span_export_timeout: int = 30000,
|
831
|
-
):
|
832
|
-
try:
|
833
|
-
if not api_key:
|
834
|
-
raise ValueError(
|
835
|
-
"api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
|
836
|
-
)
|
837
|
-
|
838
|
-
if not organization_id:
|
839
|
-
raise ValueError(
|
840
|
-
"organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
|
841
|
-
)
|
842
|
-
|
843
|
-
try:
|
844
|
-
result, response = validate_api_key(api_key)
|
845
|
-
except Exception as e:
|
846
|
-
judgeval_logger.error(
|
847
|
-
f"Issue with verifying API key, disabling monitoring: {e}"
|
848
|
-
)
|
849
|
-
enable_monitoring = False
|
850
|
-
result = True
|
851
|
-
|
852
|
-
if not result:
|
853
|
-
raise ValueError(f"Issue with passed in Judgment API key: {response}")
|
854
|
-
|
855
|
-
if use_s3 and not s3_bucket_name:
|
856
|
-
raise ValueError("S3 bucket name must be provided when use_s3 is True")
|
857
|
-
|
858
|
-
self.api_key: str = api_key
|
859
|
-
self.project_name: str = project_name or "default_project"
|
860
|
-
self.organization_id: str = organization_id
|
861
|
-
self.traces: List[Trace] = []
|
862
|
-
self.enable_monitoring: bool = enable_monitoring
|
863
|
-
self.enable_evaluations: bool = enable_evaluations
|
864
|
-
self.show_trace_urls: bool = show_trace_urls
|
865
|
-
self.class_identifiers: Dict[
|
866
|
-
str, str
|
867
|
-
] = {} # Dictionary to store class identifiers
|
868
|
-
self.span_id_to_previous_span_id: Dict[str, Union[str, None]] = {}
|
869
|
-
self.trace_id_to_previous_trace: Dict[str, Union[TraceClient, None]] = {}
|
870
|
-
self.current_span_id: Optional[str] = None
|
871
|
-
self.current_trace: Optional[TraceClient] = None
|
872
|
-
self.trace_across_async_contexts: bool = trace_across_async_contexts
|
873
|
-
Tracer.trace_across_async_contexts = trace_across_async_contexts
|
874
|
-
|
875
|
-
# Initialize S3 storage if enabled
|
876
|
-
self.use_s3 = use_s3
|
877
|
-
if use_s3:
|
878
|
-
from judgeval.common.storage.s3_storage import S3Storage
|
879
|
-
|
880
|
-
try:
|
881
|
-
self.s3_storage = S3Storage(
|
882
|
-
bucket_name=s3_bucket_name,
|
883
|
-
aws_access_key_id=s3_aws_access_key_id,
|
884
|
-
aws_secret_access_key=s3_aws_secret_access_key,
|
885
|
-
region_name=s3_region_name,
|
886
|
-
)
|
887
|
-
except Exception as e:
|
888
|
-
judgeval_logger.error(
|
889
|
-
f"Issue with initializing S3 storage, disabling S3: {e}"
|
890
|
-
)
|
891
|
-
self.use_s3 = False
|
892
|
-
|
893
|
-
self.offline_mode = False # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
|
894
|
-
self.deep_tracing: bool = deep_tracing
|
895
|
-
|
896
|
-
self.span_batch_size = span_batch_size
|
897
|
-
self.span_flush_interval = span_flush_interval
|
898
|
-
self.span_max_queue_size = span_max_queue_size
|
899
|
-
self.span_export_timeout = span_export_timeout
|
900
|
-
self.otel_span_processor: SpanProcessorBase
|
901
|
-
if enable_monitoring:
|
902
|
-
self.otel_span_processor = JudgmentSpanProcessor(
|
903
|
-
judgment_api_key=api_key,
|
904
|
-
organization_id=organization_id,
|
905
|
-
batch_size=span_batch_size,
|
906
|
-
flush_interval=span_flush_interval,
|
907
|
-
max_queue_size=span_max_queue_size,
|
908
|
-
export_timeout=span_export_timeout,
|
909
|
-
)
|
910
|
-
else:
|
911
|
-
self.otel_span_processor = SpanProcessorBase()
|
912
|
-
|
913
|
-
# Initialize local evaluation queue for custom scorers
|
914
|
-
self.local_eval_queue = LocalEvaluationQueue()
|
915
|
-
|
916
|
-
# Start workers with callback to log results only if monitoring is enabled
|
917
|
-
if enable_evaluations and enable_monitoring:
|
918
|
-
self.local_eval_queue.start_workers(
|
919
|
-
callback=self._log_eval_results_callback
|
920
|
-
)
|
921
|
-
|
922
|
-
atexit.register(self._cleanup_on_exit)
|
923
|
-
except Exception as e:
|
924
|
-
judgeval_logger.error(
|
925
|
-
f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
|
926
|
-
)
|
927
|
-
self.enable_monitoring = False
|
928
|
-
self.enable_evaluations = False
|
929
|
-
|
930
|
-
def set_current_span(
|
931
|
-
self, span_id: str
|
932
|
-
) -> Optional[contextvars.Token[Union[str, None]]]:
|
933
|
-
self.span_id_to_previous_span_id[span_id] = self.current_span_id
|
934
|
-
self.current_span_id = span_id
|
935
|
-
Tracer.current_span_id = span_id
|
936
|
-
try:
|
937
|
-
token = current_span_var.set(span_id)
|
938
|
-
except Exception:
|
939
|
-
token = None
|
940
|
-
return token
|
941
|
-
|
942
|
-
def get_current_span(self) -> Optional[str]:
|
943
|
-
try:
|
944
|
-
current_span_var_val = current_span_var.get()
|
945
|
-
except Exception:
|
946
|
-
current_span_var_val = None
|
947
|
-
return (
|
948
|
-
(self.current_span_id or current_span_var_val)
|
949
|
-
if self.trace_across_async_contexts
|
950
|
-
else current_span_var_val
|
951
|
-
)
|
952
|
-
|
953
|
-
def reset_current_span(
|
954
|
-
self,
|
955
|
-
token: Optional[contextvars.Token[Union[str, None]]] = None,
|
956
|
-
span_id: Optional[str] = None,
|
957
|
-
):
|
958
|
-
try:
|
959
|
-
if token:
|
960
|
-
current_span_var.reset(token)
|
961
|
-
except Exception:
|
962
|
-
pass
|
963
|
-
if not span_id:
|
964
|
-
span_id = self.current_span_id
|
965
|
-
if span_id:
|
966
|
-
self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
|
967
|
-
Tracer.current_span_id = self.current_span_id
|
968
|
-
|
969
|
-
def set_current_trace(
|
970
|
-
self, trace: TraceClient
|
971
|
-
) -> Optional[contextvars.Token[Union[TraceClient, None]]]:
|
972
|
-
"""
|
973
|
-
Set the current trace context in contextvars
|
974
|
-
"""
|
975
|
-
self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
|
976
|
-
self.current_trace = trace
|
977
|
-
Tracer.current_trace = trace
|
978
|
-
try:
|
979
|
-
token = current_trace_var.set(trace)
|
980
|
-
except Exception:
|
981
|
-
token = None
|
982
|
-
return token
|
983
|
-
|
984
|
-
def get_current_trace(self) -> Optional[TraceClient]:
|
985
|
-
"""
|
986
|
-
Get the current trace context.
|
987
|
-
|
988
|
-
Tries to get the trace client from the context variable first.
|
989
|
-
If not found (e.g., context lost across threads/tasks),
|
990
|
-
it falls back to the active trace client managed by the callback handler.
|
991
|
-
"""
|
992
|
-
try:
|
993
|
-
current_trace_var_val = current_trace_var.get()
|
994
|
-
except Exception:
|
995
|
-
current_trace_var_val = None
|
996
|
-
return (
|
997
|
-
(self.current_trace or current_trace_var_val)
|
998
|
-
if self.trace_across_async_contexts
|
999
|
-
else current_trace_var_val
|
1000
|
-
)
|
1001
|
-
|
1002
|
-
def reset_current_trace(
|
1003
|
-
self,
|
1004
|
-
token: Optional[contextvars.Token[Union[TraceClient, None]]] = None,
|
1005
|
-
trace_id: Optional[str] = None,
|
1006
|
-
):
|
1007
|
-
try:
|
1008
|
-
if token:
|
1009
|
-
current_trace_var.reset(token)
|
1010
|
-
except Exception:
|
1011
|
-
pass
|
1012
|
-
if not trace_id and self.current_trace:
|
1013
|
-
trace_id = self.current_trace.trace_id
|
1014
|
-
if trace_id:
|
1015
|
-
self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
|
1016
|
-
Tracer.current_trace = self.current_trace
|
1017
|
-
|
1018
|
-
@contextmanager
|
1019
|
-
def trace(
|
1020
|
-
self, name: str, project_name: Union[str, None] = None
|
1021
|
-
) -> Generator[TraceClient, None, None]:
|
1022
|
-
"""Start a new trace context using a context manager"""
|
1023
|
-
trace_id = str(uuid.uuid4())
|
1024
|
-
project = project_name if project_name is not None else self.project_name
|
1025
|
-
|
1026
|
-
# Get parent trace info from context
|
1027
|
-
parent_trace = self.get_current_trace()
|
1028
|
-
parent_trace_id = None
|
1029
|
-
parent_name = None
|
1030
|
-
|
1031
|
-
if parent_trace:
|
1032
|
-
parent_trace_id = parent_trace.trace_id
|
1033
|
-
parent_name = parent_trace.name
|
1034
|
-
|
1035
|
-
trace = TraceClient(
|
1036
|
-
self,
|
1037
|
-
trace_id,
|
1038
|
-
name,
|
1039
|
-
project_name=project,
|
1040
|
-
enable_monitoring=self.enable_monitoring,
|
1041
|
-
enable_evaluations=self.enable_evaluations,
|
1042
|
-
parent_trace_id=parent_trace_id,
|
1043
|
-
parent_name=parent_name,
|
1044
|
-
)
|
1045
|
-
|
1046
|
-
# Set the current trace in context variables
|
1047
|
-
token = self.set_current_trace(trace)
|
1048
|
-
|
1049
|
-
with trace.span(name or "unnamed_trace"):
|
1050
|
-
try:
|
1051
|
-
# Save the trace to the database to handle Evaluations' trace_id referential integrity
|
1052
|
-
yield trace
|
1053
|
-
finally:
|
1054
|
-
# Reset the context variable
|
1055
|
-
self.reset_current_trace(token)
|
1056
|
-
|
1057
|
-
def agent(
|
1058
|
-
self,
|
1059
|
-
identifier: Optional[str] = None,
|
1060
|
-
track_state: Optional[bool] = False,
|
1061
|
-
track_attributes: Optional[List[str]] = None,
|
1062
|
-
field_mappings: Optional[Dict[str, str]] = None,
|
1063
|
-
):
|
1064
|
-
"""
|
1065
|
-
Class decorator that associates a class with a custom identifier and enables state tracking.
|
1066
|
-
|
1067
|
-
This decorator creates a mapping between the class name and the provided
|
1068
|
-
identifier, which can be useful for tagging, grouping, or referencing
|
1069
|
-
classes in a standardized way. It also enables automatic state capture
|
1070
|
-
for instances of the decorated class when used with tracing.
|
1071
|
-
|
1072
|
-
Args:
|
1073
|
-
identifier: The identifier to associate with the decorated class.
|
1074
|
-
This will be used as the instance name in traces.
|
1075
|
-
track_state: Whether to automatically capture the state (attributes)
|
1076
|
-
of instances before and after function execution. Defaults to False.
|
1077
|
-
track_attributes: Optional list of specific attribute names to track.
|
1078
|
-
If None, all non-private attributes (not starting with '_')
|
1079
|
-
will be tracked when track_state=True.
|
1080
|
-
field_mappings: Optional dictionary mapping internal attribute names to
|
1081
|
-
display names in the captured state. For example:
|
1082
|
-
{"system_prompt": "instructions"} will capture the
|
1083
|
-
'instructions' attribute as 'system_prompt' in the state.
|
1084
|
-
|
1085
|
-
Example:
|
1086
|
-
@tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
|
1087
|
-
class User:
|
1088
|
-
# Class implementation
|
1089
|
-
"""
|
1090
|
-
|
1091
|
-
def decorator(cls):
|
1092
|
-
class_name = cls.__name__
|
1093
|
-
self.class_identifiers[class_name] = {
|
1094
|
-
"identifier": identifier,
|
1095
|
-
"track_state": track_state,
|
1096
|
-
"track_attributes": track_attributes,
|
1097
|
-
"field_mappings": field_mappings or {},
|
1098
|
-
"class_name": class_name,
|
1099
|
-
}
|
1100
|
-
return cls
|
1101
|
-
|
1102
|
-
return decorator
|
1103
|
-
|
1104
|
-
def identify(self, *args, **kwargs):
|
1105
|
-
judgeval_logger.warning(
|
1106
|
-
"identify() is deprecated and may not be supported in future versions of judgeval. Use the agent() decorator instead."
|
1107
|
-
)
|
1108
|
-
return self.agent(*args, **kwargs)
|
1109
|
-
|
1110
|
-
def _capture_instance_state(
|
1111
|
-
self, instance: Any, class_config: Dict[str, Any]
|
1112
|
-
) -> Dict[str, Any]:
|
1113
|
-
"""
|
1114
|
-
Capture the state of an instance based on class configuration.
|
1115
|
-
Args:
|
1116
|
-
instance: The instance to capture the state of.
|
1117
|
-
class_config: Configuration dictionary for state capture,
|
1118
|
-
expected to contain 'track_attributes' and 'field_mappings'.
|
1119
|
-
"""
|
1120
|
-
track_attributes = class_config.get("track_attributes")
|
1121
|
-
field_mappings = class_config.get("field_mappings")
|
1122
|
-
|
1123
|
-
if track_attributes:
|
1124
|
-
state = {attr: getattr(instance, attr, None) for attr in track_attributes}
|
1125
|
-
else:
|
1126
|
-
state = {
|
1127
|
-
k: v for k, v in instance.__dict__.items() if not k.startswith("_")
|
1128
|
-
}
|
1129
|
-
|
1130
|
-
if field_mappings:
|
1131
|
-
state["field_mappings"] = field_mappings
|
1132
|
-
|
1133
|
-
return state
|
1134
|
-
|
1135
|
-
def _get_instance_state_if_tracked(self, args):
|
1136
|
-
"""
|
1137
|
-
Extract instance state if the instance should be tracked.
|
1138
|
-
|
1139
|
-
Returns the captured state dict if tracking is enabled, None otherwise.
|
1140
|
-
"""
|
1141
|
-
if args and hasattr(args[0], "__class__"):
|
1142
|
-
instance = args[0]
|
1143
|
-
class_name = instance.__class__.__name__
|
1144
|
-
if (
|
1145
|
-
class_name in self.class_identifiers
|
1146
|
-
and isinstance(self.class_identifiers[class_name], dict)
|
1147
|
-
and self.class_identifiers[class_name].get("track_state", False)
|
1148
|
-
):
|
1149
|
-
return self._capture_instance_state(
|
1150
|
-
instance, self.class_identifiers[class_name]
|
1151
|
-
)
|
1152
|
-
|
1153
|
-
def _conditionally_capture_and_record_state(
|
1154
|
-
self, trace_client_instance: TraceClient, args: tuple, is_before: bool
|
1155
|
-
):
|
1156
|
-
"""Captures instance state if tracked and records it via the trace_client."""
|
1157
|
-
state = self._get_instance_state_if_tracked(args)
|
1158
|
-
if state:
|
1159
|
-
if is_before:
|
1160
|
-
trace_client_instance.record_state_before(state)
|
1161
|
-
else:
|
1162
|
-
trace_client_instance.record_state_after(state)
|
1163
|
-
|
1164
|
-
@overload
|
1165
|
-
def observe(
|
1166
|
-
self, func: T, *, name: Optional[str] = None, span_type: SpanType = "span"
|
1167
|
-
) -> T: ...
|
1168
|
-
|
1169
|
-
@overload
|
1170
|
-
def observe(
|
1171
|
-
self,
|
1172
|
-
*,
|
1173
|
-
name: Optional[str] = None,
|
1174
|
-
span_type: SpanType = "span",
|
1175
|
-
) -> Callable[[T], T]: ...
|
1176
|
-
|
1177
|
-
def observe(
|
1178
|
-
self,
|
1179
|
-
func: Optional[T] = None,
|
1180
|
-
*,
|
1181
|
-
name: Optional[str] = None,
|
1182
|
-
span_type: SpanType = "span",
|
1183
|
-
):
|
1184
|
-
"""
|
1185
|
-
Decorator to trace function execution with detailed entry/exit information.
|
1186
|
-
|
1187
|
-
Args:
|
1188
|
-
func: The function to decorate
|
1189
|
-
name: Optional custom name for the span (defaults to function name)
|
1190
|
-
span_type: Type of span (default "span").
|
1191
|
-
"""
|
1192
|
-
# If monitoring is disabled, return the function as is
|
1193
|
-
try:
|
1194
|
-
if not self.enable_monitoring:
|
1195
|
-
return func if func else lambda f: f
|
1196
|
-
|
1197
|
-
if func is None:
|
1198
|
-
return lambda func: self.observe(
|
1199
|
-
func,
|
1200
|
-
name=name,
|
1201
|
-
span_type=span_type,
|
1202
|
-
)
|
1203
|
-
|
1204
|
-
# Use provided name or fall back to function name
|
1205
|
-
original_span_name = name or func.__name__
|
1206
|
-
|
1207
|
-
# Store custom attributes on the function object
|
1208
|
-
func._judgment_span_name = original_span_name # type: ignore
|
1209
|
-
func._judgment_span_type = span_type # type: ignore
|
1210
|
-
|
1211
|
-
except Exception:
|
1212
|
-
return func
|
1213
|
-
|
1214
|
-
def _record_span_data(span, args, kwargs):
|
1215
|
-
"""Helper function to record inputs, agent info, and state on a span."""
|
1216
|
-
# Get class and agent info
|
1217
|
-
class_name = None
|
1218
|
-
agent_name = None
|
1219
|
-
if args and hasattr(args[0], "__class__"):
|
1220
|
-
class_name = args[0].__class__.__name__
|
1221
|
-
agent_name = get_instance_prefixed_name(
|
1222
|
-
args[0], class_name, self.class_identifiers
|
1223
|
-
)
|
1224
|
-
|
1225
|
-
# Record inputs, agent name, class name
|
1226
|
-
inputs = combine_args_kwargs(func, args, kwargs)
|
1227
|
-
span.record_input(inputs)
|
1228
|
-
if agent_name:
|
1229
|
-
span.record_agent_name(agent_name)
|
1230
|
-
if class_name and class_name in self.class_identifiers:
|
1231
|
-
span.record_class_name(class_name)
|
1232
|
-
|
1233
|
-
# Capture state before execution
|
1234
|
-
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
1235
|
-
|
1236
|
-
return class_name, agent_name
|
1237
|
-
|
1238
|
-
def _finalize_span_data(span, result, args):
|
1239
|
-
"""Helper function to record outputs and final state on a span."""
|
1240
|
-
# Record output
|
1241
|
-
span.record_output(result)
|
1242
|
-
|
1243
|
-
# Capture state after execution
|
1244
|
-
self._conditionally_capture_and_record_state(span, args, is_before=False)
|
1245
|
-
|
1246
|
-
def _cleanup_trace(current_trace, trace_token, wrapper_type="function"):
|
1247
|
-
"""Helper function to handle trace cleanup in finally blocks."""
|
1248
|
-
try:
|
1249
|
-
trace_id, server_response = current_trace.save(final_save=True)
|
1250
|
-
|
1251
|
-
complete_trace_data = {
|
1252
|
-
"trace_id": current_trace.trace_id,
|
1253
|
-
"name": current_trace.name,
|
1254
|
-
"project_name": current_trace.project_name,
|
1255
|
-
"created_at": datetime.fromtimestamp(
|
1256
|
-
current_trace.start_time or time.time(),
|
1257
|
-
timezone.utc,
|
1258
|
-
).isoformat(),
|
1259
|
-
"duration": current_trace.get_duration(),
|
1260
|
-
"trace_spans": [
|
1261
|
-
span.model_dump() for span in current_trace.trace_spans
|
1262
|
-
],
|
1263
|
-
"evaluation_runs": [
|
1264
|
-
run.model_dump() for run in current_trace.evaluation_runs
|
1265
|
-
],
|
1266
|
-
"offline_mode": self.offline_mode,
|
1267
|
-
"parent_trace_id": current_trace.parent_trace_id,
|
1268
|
-
"parent_name": current_trace.parent_name,
|
1269
|
-
"customer_id": current_trace.customer_id,
|
1270
|
-
"tags": current_trace.tags,
|
1271
|
-
"metadata": current_trace.metadata,
|
1272
|
-
"update_id": current_trace.update_id,
|
1273
|
-
}
|
1274
|
-
self.traces.append(complete_trace_data)
|
1275
|
-
self.reset_current_trace(trace_token)
|
1276
|
-
except Exception as e:
|
1277
|
-
judgeval_logger.warning(f"Issue with {wrapper_type} cleanup: {e}")
|
1278
|
-
|
1279
|
-
def _execute_in_span(
|
1280
|
-
current_trace, span_name, span_type, execution_func, args, kwargs
|
1281
|
-
):
|
1282
|
-
"""Helper function to execute code within a span context."""
|
1283
|
-
with current_trace.span(span_name, span_type=span_type) as span:
|
1284
|
-
_record_span_data(span, args, kwargs)
|
-
-                try:
-                    result = execution_func()
-                    _finalize_span_data(span, result, args)
-                    return result
-                except Exception as e:
-                    _capture_exception_for_trace(current_trace, sys.exc_info())
-                    raise e
-
-        async def _execute_in_span_async(
-            current_trace, span_name, span_type, async_execution_func, args, kwargs
-        ):
-            """Helper function to execute async code within a span context."""
-            with current_trace.span(span_name, span_type=span_type) as span:
-                _record_span_data(span, args, kwargs)
-
-                try:
-                    result = await async_execution_func()
-                    _finalize_span_data(span, result, args)
-                    return result
-                except Exception as e:
-                    _capture_exception_for_trace(current_trace, sys.exc_info())
-                    raise e
-
-        def _create_new_trace(self, span_name):
-            """Helper function to create a new trace and set it as current."""
-            trace_id = str(uuid.uuid4())
-            project = self.project_name
-
-            current_trace = TraceClient(
-                self,
-                trace_id,
-                span_name,
-                project_name=project,
-                enable_monitoring=self.enable_monitoring,
-                enable_evaluations=self.enable_evaluations,
-            )
-
-            trace_token = self.set_current_trace(current_trace)
-            return current_trace, trace_token
-
-        def _execute_with_auto_trace_creation(
-            span_name, span_type, execution_func, args, kwargs
-        ):
-            """Helper function that handles automatic trace creation and span execution."""
-            current_trace = self.get_current_trace()
-
-            if not current_trace:
-                current_trace, trace_token = _create_new_trace(self, span_name)
-
-                try:
-                    result = _execute_in_span(
-                        current_trace,
-                        span_name,
-                        span_type,
-                        execution_func,
-                        args,
-                        kwargs,
-                    )
-                    return result
-                finally:
-                    # Cleanup the trace we created
-                    _cleanup_trace(current_trace, trace_token, "auto_trace")
-            else:
-                # Use existing trace
-                return _execute_in_span(
-                    current_trace, span_name, span_type, execution_func, args, kwargs
-                )
-
-        async def _execute_with_auto_trace_creation_async(
-            span_name, span_type, async_execution_func, args, kwargs
-        ):
-            """Helper function that handles automatic trace creation and async span execution."""
-            current_trace = self.get_current_trace()
-
-            if not current_trace:
-                current_trace, trace_token = _create_new_trace(self, span_name)
-
-                try:
-                    result = await _execute_in_span_async(
-                        current_trace,
-                        span_name,
-                        span_type,
-                        async_execution_func,
-                        args,
-                        kwargs,
-                    )
-                    return result
-                finally:
-                    # Cleanup the trace we created
-                    _cleanup_trace(current_trace, trace_token, "async_auto_trace")
-            else:
-                # Use existing trace
-                return await _execute_in_span_async(
-                    current_trace,
-                    span_name,
-                    span_type,
-                    async_execution_func,
-                    args,
-                    kwargs,
-                )
-
-        # Check for generator functions first
-        if inspect.isgeneratorfunction(func):
-
-            @functools.wraps(func)
-            def generator_wrapper(*args, **kwargs):
-                # Get the generator from the original function
-                generator = func(*args, **kwargs)
-
-                # Create wrapper generator that creates spans for each yield
-                def traced_generator():
-                    while True:
-                        try:
-                            # Handle automatic trace creation and span execution
-                            item = _execute_with_auto_trace_creation(
-                                original_span_name,
-                                span_type,
-                                lambda: next(generator),
-                                args,
-                                kwargs,
-                            )
-                            yield item
-                        except StopIteration:
-                            break
-
-                return traced_generator()
-
-            return generator_wrapper
-
-        # Check for async generator functions
-        elif inspect.isasyncgenfunction(func):
-
-            @functools.wraps(func)
-            def async_generator_wrapper(*args, **kwargs):
-                # Get the async generator from the original function
-                async_generator = func(*args, **kwargs)
-
-                # Create wrapper async generator that creates spans for each yield
-                async def traced_async_generator():
-                    while True:
-                        try:
-                            # Handle automatic trace creation and span execution
-                            item = await _execute_with_auto_trace_creation_async(
-                                original_span_name,
-                                span_type,
-                                lambda: async_generator.__anext__(),
-                                args,
-                                kwargs,
-                            )
-                            if inspect.iscoroutine(item):
-                                item = await item
-                            yield item
-                        except StopAsyncIteration:
-                            break
-
-                return traced_async_generator()
-
-            return async_generator_wrapper
-
-        elif asyncio.iscoroutinefunction(func):
-
-            @functools.wraps(func)
-            async def async_wrapper(*args, **kwargs):
-                nonlocal original_span_name
-                span_name = original_span_name
-
-                async def async_execution():
-                    if self.deep_tracing:
-                        with _DeepTracer(self):
-                            return await func(*args, **kwargs)
-                    else:
-                        return await func(*args, **kwargs)
-
-                result = await _execute_with_auto_trace_creation_async(
-                    span_name, span_type, async_execution, args, kwargs
-                )
-
-                return result
-
-            return async_wrapper
-        else:
-            # Non-async function implementation with deep tracing
-            @functools.wraps(func)
-            def wrapper(*args, **kwargs):
-                nonlocal original_span_name
-                span_name = original_span_name
-
-                def sync_execution():
-                    if self.deep_tracing:
-                        with _DeepTracer(self):
-                            return func(*args, **kwargs)
-                    else:
-                        return func(*args, **kwargs)
-
-                return _execute_with_auto_trace_creation(
-                    span_name, span_type, sync_execution, args, kwargs
-                )
-
-            return wrapper
-
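The four removed branches above dispatch on the kind of callable handed to `observe`: plain generators get a span per `next()` call, async generators a span per `__anext__()`, and coroutines and plain functions a single span around the call. Below is a minimal, self-contained sketch of that dispatch, with the removed span helpers replaced by a hypothetical `log_span` stub (the real machinery now lives in the new `judgeval/tracer` package):

```python
import asyncio
import functools
import inspect


def log_span(name, item):
    # Hypothetical stand-in for the removed span-recording helpers.
    print(f"[span:{name}] {item!r}")


def observe_like(func):
    """Dispatch mirroring the removed wrappers: one branch per callable kind."""
    name = func.__qualname__
    if inspect.isgeneratorfunction(func):
        @functools.wraps(func)
        def gen_wrapper(*args, **kwargs):
            for item in func(*args, **kwargs):  # one "span" per yielded item
                log_span(name, item)
                yield item
        return gen_wrapper
    elif inspect.isasyncgenfunction(func):
        @functools.wraps(func)
        async def agen_wrapper(*args, **kwargs):
            async for item in func(*args, **kwargs):
                log_span(name, item)
                yield item
        return agen_wrapper
    elif asyncio.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)
            log_span(name, result)
            return result
        return async_wrapper
    else:
        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            log_span(name, result)
            return result
        return sync_wrapper


@observe_like
def count(n):
    yield from range(n)


assert list(count(3)) == [0, 1, 2]
```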
-    def observe_tools(
-        self,
-        cls=None,
-        *,
-        exclude_methods: Optional[List[str]] = None,
-        include_private: bool = False,
-        warn_on_double_decoration: bool = True,
-    ):
-        """
-        Automatically adds @observe(span_type="tool") to all methods in a class.
-
-        Args:
-            cls: The class to decorate (automatically provided when used as decorator)
-            exclude_methods: List of method names to skip decorating. Defaults to common magic methods
-            include_private: Whether to decorate methods starting with underscore. Defaults to False
-            warn_on_double_decoration: Whether to print warnings when skipping already-decorated methods. Defaults to True
-        """
-
-        if exclude_methods is None:
-            exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]
-
-        def decorate_class(cls):
-            if not self.enable_monitoring:
-                return cls
-
-            decorated = []
-            skipped = []
-
-            for name in dir(cls):
-                method = getattr(cls, name)
-
-                if (
-                    not callable(method)
-                    or name in exclude_methods
-                    or (name.startswith("_") and not include_private)
-                    or not hasattr(cls, name)
-                ):
-                    continue
-
-                if hasattr(method, "_judgment_span_name"):
-                    skipped.append(name)
-                    if warn_on_double_decoration:
-                        judgeval_logger.info(
-                            f"{cls.__name__}.{name} already decorated, skipping"
-                        )
-                    continue
-
-                try:
-                    decorated_method = self.observe(method, span_type="tool")
-                    setattr(cls, name, decorated_method)
-                    decorated.append(name)
-                except Exception as e:
-                    if warn_on_double_decoration:
-                        judgeval_logger.warning(
-                            f"Failed to decorate {cls.__name__}.{name}: {e}"
-                        )
-
-            return cls
-
-        return decorate_class if cls is None else decorate_class(cls)
-
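Per the docstring, `observe_tools` was a class decorator. A hypothetical usage sketch, assuming a configured 0.7.x `Tracer` instance named `judgment`:

```python
# Hypothetical usage of the removed observe_tools API; `judgment` is
# assumed to be a configured judgeval 0.7.x Tracer instance.
@judgment.observe_tools(exclude_methods=["__init__", "plan"])
class ResearchAgent:
    def search(self, query: str) -> str:  # wrapped as span_type="tool"
        return f"results for {query}"

    def _parse(self, raw: str) -> str:  # skipped: private by default
        return raw.strip()
```

Methods already carrying `_judgment_span_name` (i.e. already decorated with `@observe`) are skipped rather than double-wrapped.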
-    def async_evaluate(
-        self,
-        scorer: Union[APIScorerConfig, BaseScorer],
-        example: Example,
-        model: str = DEFAULT_GPT_MODEL,
-        sampling_rate: float = 1,
-    ):
-        try:
-            if not self.enable_monitoring or not self.enable_evaluations:
-                return
-
-            if not isinstance(scorer, (APIScorerConfig, BaseScorer)):
-                judgeval_logger.warning(
-                    f"Scorer must be an instance of APIScorerConfig or BaseScorer, got {type(scorer)}, skipping evaluation"
-                )
-                return
-
-            if not isinstance(example, Example):
-                judgeval_logger.warning(
-                    f"Example must be an instance of Example, got {type(example)} skipping evaluation"
-                )
-                return
-
-            if sampling_rate < 0:
-                judgeval_logger.warning(
-                    "Cannot set sampling_rate below 0, skipping evaluation"
-                )
-                return
-
-            if sampling_rate > 1:
-                judgeval_logger.warning(
-                    "Cannot set sampling_rate above 1, skipping evaluation"
-                )
-                return
-
-            percentage = random.uniform(0, 1)
-            if percentage > sampling_rate:
-                judgeval_logger.info("Skipping async_evaluate due to sampling rate")
-                return
-
-            current_trace = self.get_current_trace()
-            if current_trace:
-                current_trace.async_evaluate(
-                    scorer=scorer, example=example, model=model
-                )
-            else:
-                judgeval_logger.warning(
-                    "No trace found (context var or fallback), skipping evaluation"
-                )
-        except Exception as e:
-            judgeval_logger.warning(f"Issue with async_evaluate: {e}")
-
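`async_evaluate` gates work with a uniform draw: it evaluates only when `random.uniform(0, 1) <= sampling_rate`, so a rate of 0.25 evaluates roughly a quarter of calls, and out-of-range rates skip entirely. A hypothetical call sketch against this removed 0.7.x API:

```python
# Sketch only; `judgment` and `my_scorer` are assumed to exist
# (a Tracer instance and an APIScorerConfig/BaseScorer respectively).
from judgeval.data import Example

judgment.async_evaluate(
    scorer=my_scorer,
    example=Example(
        input="What is the capital of France?",
        actual_output="Paris",
    ),
    sampling_rate=0.25,  # evaluate ~25% of calls on the active trace
)
```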
-    def update_metadata(self, metadata: dict):
-        """
-        Update metadata for the current trace.
-
-        Args:
-            metadata: Metadata as a dictionary
-        """
-        current_trace = self.get_current_trace()
-        if current_trace:
-            current_trace.update_metadata(metadata)
-        else:
-            judgeval_logger.warning("No current trace found, cannot set metadata")
-
-    def set_customer_id(self, customer_id: str):
-        """
-        Set the customer ID for the current trace.
-
-        Args:
-            customer_id: The customer ID to set
-        """
-        current_trace = self.get_current_trace()
-        if current_trace:
-            current_trace.set_customer_id(customer_id)
-        else:
-            judgeval_logger.warning("No current trace found, cannot set customer ID")
-
-    def set_tags(self, tags: List[Union[str, set, tuple]]):
-        """
-        Set the tags for the current trace.
-
-        Args:
-            tags: List of tags to set
-        """
-        current_trace = self.get_current_trace()
-        if current_trace:
-            current_trace.set_tags(tags)
-        else:
-            judgeval_logger.warning("No current trace found, cannot set tags")
-
-    def set_reward_score(self, reward_score: Union[float, Dict[str, float]]):
-        """
-        Set the reward score for this trace to be used for RL or SFT.
-
-        Args:
-            reward_score: The reward score to set
-        """
-        current_trace = self.get_current_trace()
-        if current_trace:
-            current_trace.set_reward_score(reward_score)
-        else:
-            judgeval_logger.warning("No current trace found, cannot set reward score")
-
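All four setters above resolve the active trace and warn-and-noop when none exists, so they are safe to call unconditionally from instrumented code. A hypothetical sketch, again assuming a 0.7.x `Tracer` named `judgment`:

```python
@judgment.observe(span_type="function")
def handle_request(user_id: str):
    judgment.set_customer_id(user_id)
    judgment.set_tags(["beta", "checkout"])
    judgment.update_metadata({"region": "us-east-1"})
    # Accepts a scalar or a per-objective mapping:
    judgment.set_reward_score({"helpfulness": 0.8, "safety": 1.0})
```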
-    def get_otel_span_processor(self) -> SpanProcessorBase:
-        """Get the OpenTelemetry span processor instance."""
-        return self.otel_span_processor
-
-    def flush_background_spans(self, timeout_millis: int = 30000):
-        """Flush all pending spans in the background service."""
-        self.otel_span_processor.force_flush(timeout_millis)
-
-    def shutdown_background_service(self):
-        """Shutdown the background span service."""
-        self.otel_span_processor.shutdown()
-        self.otel_span_processor = SpanProcessorBase()
-
-    def wait_for_completion(self, timeout: Optional[float] = 30.0) -> bool:
-        """Wait for all evaluations and span processing to complete.
-
-        This method blocks until all queued evaluations are processed and
-        all pending spans are flushed to the server.
-
-        Args:
-            timeout: Maximum time to wait in seconds. Defaults to 30 seconds.
-                None means wait indefinitely.
-
-        Returns:
-            True if all processing completed within the timeout, False otherwise.
-
-        """
-        try:
-            judgeval_logger.debug(
-                "Waiting for all evaluations and spans to complete..."
-            )
-
-            # Wait for all queued evaluation work to complete
-            eval_completed = self.local_eval_queue.wait_for_completion()
-            if not eval_completed:
-                judgeval_logger.warning(
-                    f"Local evaluation queue did not complete within {timeout} seconds"
-                )
-                return False
-
-            self.flush_background_spans()
-
-            judgeval_logger.debug("All evaluations and spans completed successfully")
-            return True
-
-        except Exception as e:
-            judgeval_logger.warning(f"Error while waiting for completion: {e}")
-            return False
-
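One quirk worth flagging in the removed `wait_for_completion`: the `timeout` argument is interpolated into the warning message but never forwarded to `local_eval_queue.wait_for_completion()`, so the queue waits on its own default. The intended use was a pre-shutdown barrier, roughly:

```python
# Sketch: block until queued evaluations and pending spans are flushed.
if not judgment.wait_for_completion(timeout=60.0):
    print("some spans or evaluations were still pending at shutdown")
```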
-    def _log_eval_results_callback(self, evaluation_run, scoring_results):
-        """Callback to log evaluation results after local processing."""
-        try:
-            if scoring_results and self.enable_evaluations and self.enable_monitoring:
-                # Convert scoring results to the format expected by API client
-                results_dict = [
-                    result.model_dump(warnings=False) for result in scoring_results
-                ]
-                api_client = JudgmentApiClient(self.api_key, self.organization_id)
-                api_client.log_evaluation_results(
-                    results_dict, evaluation_run.model_dump(warnings=False)
-                )
-        except Exception as e:
-            judgeval_logger.warning(f"Failed to log local evaluation results: {e}")
-
-    def _cleanup_on_exit(self):
-        """Cleanup handler called on application exit to ensure spans are flushed."""
-        try:
-            # Wait for all queued evaluation work to complete before stopping
-            completed = self.local_eval_queue.wait_for_completion()
-            if not completed:
-                judgeval_logger.warning(
-                    "Local evaluation queue did not complete within 30 seconds"
-                )
-
-            self.local_eval_queue.stop_workers()
-            self.flush_background_spans()
-        except Exception as e:
-            judgeval_logger.warning(f"Error during tracer cleanup: {e}")
-        finally:
-            try:
-                self.shutdown_background_service()
-            except Exception as e:
-                judgeval_logger.warning(
-                    f"Error during background service shutdown: {e}"
-                )
-
-    def trace_to_message_history(
-        self, trace: Union[Trace, TraceClient]
-    ) -> List[Dict[str, str]]:
-        """
-        Extract message history from a trace for training purposes.
-
-        This method processes trace spans to reconstruct the conversation flow,
-        extracting messages in chronological order from LLM, user, and tool spans.
-
-        Args:
-            trace: Trace or TraceClient instance to extract messages from
-
-        Returns:
-            List of message dictionaries with 'role' and 'content' keys
-
-        Raises:
-            ValueError: If no trace is provided
-        """
-        if not trace:
-            raise ValueError("No trace provided")
-
-        # Handle both Trace and TraceClient objects
-        if isinstance(trace, TraceClient):
-            spans = trace.trace_spans
-        else:
-            spans = trace.trace_spans if hasattr(trace, "trace_spans") else []
-
-        messages = []
-        first_found = False
-
-        # Process spans in chronological order
-        for span in sorted(
-            spans, key=lambda s: s.created_at if hasattr(s, "created_at") else 0
-        ):
-            # Skip spans without output (except for first LLM span which may have input messages)
-            if span.output is None and span.span_type != "llm":
-                continue
-
-            if span.span_type == "llm":
-                # For the first LLM span, extract input messages (system + user prompts)
-                if not first_found and hasattr(span, "inputs") and span.inputs:
-                    input_messages = span.inputs.get("messages", [])
-                    if input_messages:
-                        first_found = True
-                        # Add input messages (typically system and user messages)
-                        for msg in input_messages:
-                            if (
-                                isinstance(msg, dict)
-                                and "role" in msg
-                                and "content" in msg
-                            ):
-                                messages.append(
-                                    {"role": msg["role"], "content": msg["content"]}
-                                )
-
-                # Add assistant response from span output
-                if span.output is not None:
-                    messages.append({"role": "assistant", "content": str(span.output)})
-
-            elif span.span_type == "user":
-                # Add user messages
-                if span.output is not None:
-                    messages.append({"role": "user", "content": str(span.output)})
-
-            elif span.span_type == "tool":
-                # Add tool responses as user messages (common pattern in training)
-                if span.output is not None:
-                    messages.append({"role": "user", "content": str(span.output)})
-
-        return messages
-
-    def get_current_message_history(self) -> List[Dict[str, str]]:
-        """
-        Get message history from the current trace.
-
-        Returns:
-            List of message dictionaries from the current trace context
-
-        Raises:
-            ValueError: If no current trace is found
-        """
-        current_trace = self.get_current_trace()
-        if not current_trace:
-            raise ValueError("No current trace found")
-
-        return self.trace_to_message_history(current_trace)
-
-
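The mapping above reduces a span list to chat messages: the first LLM span contributes its input `messages`, every LLM output becomes an `assistant` turn, and `user`/`tool` outputs both become `user` turns. A self-contained sketch of the same reduction over stand-in span objects:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class FakeSpan:
    span_type: str
    created_at: float
    output: Optional[Any] = None
    inputs: Dict[str, Any] = field(default_factory=dict)


def to_message_history(spans: List[FakeSpan]) -> List[Dict[str, str]]:
    messages: List[Dict[str, str]] = []
    first_found = False
    for span in sorted(spans, key=lambda s: s.created_at):
        if span.output is None and span.span_type != "llm":
            continue  # only an LLM span may contribute without an output
        if span.span_type == "llm":
            if not first_found and span.inputs.get("messages"):
                first_found = True
                messages += [
                    {"role": m["role"], "content": m["content"]}
                    for m in span.inputs["messages"]
                    if isinstance(m, dict) and "role" in m and "content" in m
                ]
            if span.output is not None:
                messages.append({"role": "assistant", "content": str(span.output)})
        elif span.span_type in ("user", "tool"):
            messages.append({"role": "user", "content": str(span.output)})
    return messages


spans = [
    FakeSpan("llm", 1.0, "Paris",
             inputs={"messages": [{"role": "user", "content": "Capital of France?"}]}),
    FakeSpan("tool", 2.0, "population: 2.1M"),
]
assert to_message_history(spans) == [
    {"role": "user", "content": "Capital of France?"},
    {"role": "assistant", "content": "Paris"},
    {"role": "user", "content": "population: 2.1M"},
]
```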
-def _get_current_trace(
-    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
-):
-    if trace_across_async_contexts:
-        return Tracer.current_trace
-    else:
-        return current_trace_var.get()
-
-
-def wrap(
-    client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
-) -> Any:
-    """
-    Wraps an API client to add tracing capabilities.
-    Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
-    Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
-    """
-    (
-        span_name,
-        original_create,
-        original_responses_create,
-        original_stream,
-        original_beta_parse,
-    ) = _get_client_config(client)
-
-    def process_span(span, response):
-        """Format and record the output in the span"""
-        output, usage = _format_output_data(client, response)
-        span.record_output(output)
-        span.record_usage(usage)
-
-        return response
-
-    def wrapped(function):
-        def wrapper(*args, **kwargs):
-            current_trace = _get_current_trace(trace_across_async_contexts)
-            if not current_trace:
-                return function(*args, **kwargs)
-
-            with current_trace.span(span_name, span_type="llm") as span:
-                span.record_input(kwargs)
-
-                try:
-                    response = function(*args, **kwargs)
-                    return process_span(span, response)
-                except Exception as e:
-                    _capture_exception_for_trace(span, sys.exc_info())
-                    raise e
-
-        return wrapper
-
-    def wrapped_async(function):
-        async def wrapper(*args, **kwargs):
-            current_trace = _get_current_trace(trace_across_async_contexts)
-            if not current_trace:
-                return await function(*args, **kwargs)
-
-            with current_trace.span(span_name, span_type="llm") as span:
-                span.record_input(kwargs)
-
-                try:
-                    response = await function(*args, **kwargs)
-                    return process_span(span, response)
-                except Exception as e:
-                    _capture_exception_for_trace(span, sys.exc_info())
-                    raise e
-
-        return wrapper
-
-    if HAS_OPENAI:
-        from judgeval.common.tracer.providers import openai_OpenAI, openai_AsyncOpenAI
-
-        assert openai_OpenAI is not None, "OpenAI client not found"
-        assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
-        if isinstance(client, (openai_OpenAI)):
-            setattr(client.chat.completions, "create", wrapped(original_create))
-            setattr(client.responses, "create", wrapped(original_responses_create))
-            setattr(client.beta.chat.completions, "parse", wrapped(original_beta_parse))
-        elif isinstance(client, (openai_AsyncOpenAI)):
-            setattr(client.chat.completions, "create", wrapped_async(original_create))
-            setattr(
-                client.responses, "create", wrapped_async(original_responses_create)
-            )
-            setattr(
-                client.beta.chat.completions,
-                "parse",
-                wrapped_async(original_beta_parse),
-            )
-
-    if HAS_TOGETHER:
-        from judgeval.common.tracer.providers import (
-            together_Together,
-            together_AsyncTogether,
-        )
-
-        assert together_Together is not None, "Together client not found"
-        assert together_AsyncTogether is not None, "Together async client not found"
-        if isinstance(client, (together_Together)):
-            setattr(client.chat.completions, "create", wrapped(original_create))
-        elif isinstance(client, (together_AsyncTogether)):
-            setattr(client.chat.completions, "create", wrapped_async(original_create))
-
-    if HAS_ANTHROPIC:
-        from judgeval.common.tracer.providers import (
-            anthropic_Anthropic,
-            anthropic_AsyncAnthropic,
-        )
-
-        assert anthropic_Anthropic is not None, "Anthropic client not found"
-        assert anthropic_AsyncAnthropic is not None, "Anthropic async client not found"
-        if isinstance(client, (anthropic_Anthropic)):
-            setattr(client.messages, "create", wrapped(original_create))
-        elif isinstance(client, (anthropic_AsyncAnthropic)):
-            setattr(client.messages, "create", wrapped_async(original_create))
-
-    if HAS_GOOGLE_GENAI:
-        from judgeval.common.tracer.providers import (
-            google_genai_Client,
-            google_genai_AsyncClient,
-        )
-
-        assert google_genai_Client is not None, "Google GenAI client not found"
-        assert google_genai_AsyncClient is not None, (
-            "Google GenAI async client not found"
-        )
-        if isinstance(client, (google_genai_Client)):
-            setattr(client.models, "generate_content", wrapped(original_create))
-        elif isinstance(client, (google_genai_AsyncClient)):
-            setattr(client.models, "generate_content", wrapped_async(original_create))
-
-    if HAS_GROQ:
-        from judgeval.common.tracer.providers import groq_Groq, groq_AsyncGroq
-
-        assert groq_Groq is not None, "Groq client not found"
-        assert groq_AsyncGroq is not None, "Groq async client not found"
-        if isinstance(client, (groq_Groq)):
-            setattr(client.chat.completions, "create", wrapped(original_create))
-        elif isinstance(client, (groq_AsyncGroq)):
-            setattr(client.chat.completions, "create", wrapped_async(original_create))
-
-    # Check for TrainableModel from judgeval.common.trainer
-    try:
-        from judgeval.common.trainer import TrainableModel
-
-        if isinstance(client, TrainableModel):
-            # Define a wrapper function that can be reapplied to new model instances
-            def wrap_model_instance(model_instance):
-                """Wrap a model instance with tracing functionality"""
-                if hasattr(model_instance, "chat") and hasattr(
-                    model_instance.chat, "completions"
-                ):
-                    if hasattr(model_instance.chat.completions, "create"):
-                        setattr(
-                            model_instance.chat.completions,
-                            "create",
-                            wrapped(model_instance.chat.completions.create),
-                        )
-                    if hasattr(model_instance.chat.completions, "acreate"):
-                        setattr(
-                            model_instance.chat.completions,
-                            "acreate",
-                            wrapped_async(model_instance.chat.completions.acreate),
-                        )
-
-            # Register the wrapper function with the TrainableModel
-            client._register_tracer_wrapper(wrap_model_instance)
-
-            # Apply wrapping to the current model
-            wrap_model_instance(client._current_model)
-    except ImportError:
-        pass  # TrainableModel not available
-
-    return client
-
-
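`wrap` patches the client in place and returns the same object, so calls made inside a traced context are recorded as `llm` spans with input, output, and token usage, while calls with no active trace pass straight through. Typical 0.7.x usage was roughly:

```python
# Sketch: requires the openai package and an active judgeval trace context.
from openai import OpenAI

client = wrap(OpenAI())  # same object back, create/parse methods patched
# Subsequent calls such as
#   client.chat.completions.create(model="gpt-4o-mini", messages=[...])
# are recorded on the active trace; without one they run unmodified.
```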
-# Helper functions for client-specific operations
-
-
-def _get_client_config(
-    client: ApiClient,
-) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
-    """Returns configuration tuple for the given API client.
-
-    Args:
-        client: An instance of OpenAI, Together, or Anthropic client
-
-    Returns:
-        tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
-            - span_name: String identifier for tracing
-            - create_method: Reference to the client's creation method
-            - responses_method: Reference to the client's responses method (if applicable)
-            - stream_method: Reference to the client's stream method (if applicable)
-            - beta_parse_method: Reference to the client's beta parse method (if applicable)
-
-    Raises:
-        ValueError: If client type is not supported
-    """
-
-    if HAS_OPENAI:
-        from judgeval.common.tracer.providers import openai_OpenAI, openai_AsyncOpenAI
-
-        assert openai_OpenAI is not None, "OpenAI client not found"
-        assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
-        if isinstance(client, (openai_OpenAI)):
-            return (
-                "OPENAI_API_CALL",
-                client.chat.completions.create,
-                client.responses.create,
-                None,
-                client.beta.chat.completions.parse,
-            )
-        elif isinstance(client, (openai_AsyncOpenAI)):
-            return (
-                "OPENAI_API_CALL",
-                client.chat.completions.create,
-                client.responses.create,
-                None,
-                client.beta.chat.completions.parse,
-            )
-    if HAS_TOGETHER:
-        from judgeval.common.tracer.providers import (
-            together_Together,
-            together_AsyncTogether,
-        )
-
-        assert together_Together is not None, "Together client not found"
-        assert together_AsyncTogether is not None, "Together async client not found"
-        if isinstance(client, (together_Together)):
-            return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
-        elif isinstance(client, (together_AsyncTogether)):
-            return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
-    if HAS_ANTHROPIC:
-        from judgeval.common.tracer.providers import (
-            anthropic_Anthropic,
-            anthropic_AsyncAnthropic,
-        )
-
-        assert anthropic_Anthropic is not None, "Anthropic client not found"
-        assert anthropic_AsyncAnthropic is not None, "Anthropic async client not found"
-        if isinstance(client, (anthropic_Anthropic)):
-            return (
-                "ANTHROPIC_API_CALL",
-                client.messages.create,
-                None,
-                client.messages.stream,
-                None,
-            )
-        elif isinstance(client, (anthropic_AsyncAnthropic)):
-            return (
-                "ANTHROPIC_API_CALL",
-                client.messages.create,
-                None,
-                client.messages.stream,
-                None,
-            )
-    if HAS_GOOGLE_GENAI:
-        from judgeval.common.tracer.providers import (
-            google_genai_Client,
-            google_genai_AsyncClient,
-        )
-
-        assert google_genai_Client is not None, "Google GenAI client not found"
-        assert google_genai_AsyncClient is not None, (
-            "Google GenAI async client not found"
-        )
-        if isinstance(client, (google_genai_Client)):
-            return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
-        elif isinstance(client, (google_genai_AsyncClient)):
-            return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
-    if HAS_GROQ:
-        from judgeval.common.tracer.providers import groq_Groq, groq_AsyncGroq
-
-        assert groq_Groq is not None, "Groq client not found"
-        assert groq_AsyncGroq is not None, "Groq async client not found"
-        if isinstance(client, (groq_Groq)):
-            return "GROQ_API_CALL", client.chat.completions.create, None, None, None
-        elif isinstance(client, (groq_AsyncGroq)):
-            return "GROQ_API_CALL", client.chat.completions.create, None, None, None
-
-    # Check for TrainableModel
-    try:
-        from judgeval.common.trainer import TrainableModel
-
-        if isinstance(client, TrainableModel):
-            return (
-                "FIREWORKS_TRAINABLE_MODEL_CALL",
-                client._current_model.chat.completions.create,
-                None,
-                None,
-                None,
-            )
-    except ImportError:
-        pass  # TrainableModel not available
-
-    raise ValueError(f"Unsupported client type: {type(client)}")
-
-
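The contract above is positional, so callers unpack all five slots even when most are `None`. For a synchronous OpenAI client, the tuple resolves as in this sketch:

```python
# Sketch of the five-slot contract for an OpenAI client instance.
span_name, create, responses_create, stream, beta_parse = _get_client_config(client)
# span_name        == "OPENAI_API_CALL"
# create           is client.chat.completions.create
# responses_create is client.responses.create
# stream           is None  (only Anthropic clients fill this slot)
# beta_parse       is client.beta.chat.completions.parse
```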
-def _format_output_data(
-    client: ApiClient, response: Any
-) -> tuple[Optional[str], Optional[TraceUsage]]:
-    """Format API response data based on client type.
-
-    Normalizes different response formats into a consistent structure
-    for tracing purposes.
-
-    Returns:
-        dict containing:
-        - content: The generated text
-        - usage: Token usage statistics
-    """
-    prompt_tokens = 0
-    completion_tokens = 0
-    cache_read_input_tokens = 0
-    cache_creation_input_tokens = 0
-    model_name = None
-    message_content = None
-
-    if HAS_OPENAI:
-        from judgeval.common.tracer.providers import (
-            openai_OpenAI,
-            openai_AsyncOpenAI,
-            openai_ChatCompletion,
-            openai_Response,
-            openai_ParsedChatCompletion,
-        )
-
-        assert openai_OpenAI is not None, "OpenAI client not found"
-        assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
-        assert openai_ChatCompletion is not None, "OpenAI chat completion not found"
-        assert openai_Response is not None, "OpenAI response not found"
-        assert openai_ParsedChatCompletion is not None, (
-            "OpenAI parsed chat completion not found"
-        )
-
-        if isinstance(client, (openai_OpenAI, openai_AsyncOpenAI)):
-            if isinstance(response, openai_ChatCompletion):
-                model_name = response.model
-                prompt_tokens = response.usage.prompt_tokens if response.usage else 0
-                completion_tokens = (
-                    response.usage.completion_tokens if response.usage else 0
-                )
-                cache_read_input_tokens = (
-                    response.usage.prompt_tokens_details.cached_tokens
-                    if response.usage
-                    and response.usage.prompt_tokens_details
-                    and response.usage.prompt_tokens_details.cached_tokens
-                    else 0
-                )
-
-                if isinstance(response, openai_ParsedChatCompletion):
-                    message_content = response.choices[0].message.parsed
-                else:
-                    message_content = response.choices[0].message.content
-            elif isinstance(response, openai_Response):
-                model_name = response.model
-                prompt_tokens = response.usage.input_tokens if response.usage else 0
-                completion_tokens = (
-                    response.usage.output_tokens if response.usage else 0
-                )
-                cache_read_input_tokens = (
-                    response.usage.input_tokens_details.cached_tokens
-                    if response.usage and response.usage.input_tokens_details
-                    else 0
-                )
-                if hasattr(response.output[0], "content"):
-                    message_content = "".join(
-                        seg.text
-                        for seg in response.output[0].content
-                        if hasattr(seg, "text")
-                    )
-            # Note: LiteLLM seems to use cache_read_input_tokens to calculate the cost for OpenAI
-            return message_content, _create_usage(
-                model_name,
-                prompt_tokens,
-                completion_tokens,
-                cache_read_input_tokens,
-                cache_creation_input_tokens,
-            )
-
-    if HAS_TOGETHER:
-        from judgeval.common.tracer.providers import (
-            together_Together,
-            together_AsyncTogether,
-        )
-
-        assert together_Together is not None, "Together client not found"
-        assert together_AsyncTogether is not None, "Together async client not found"
-
-        if isinstance(client, (together_Together, together_AsyncTogether)):
-            model_name = "together_ai/" + response.model
-            prompt_tokens = response.usage.prompt_tokens
-            completion_tokens = response.usage.completion_tokens
-            message_content = response.choices[0].message.content
-
-            # As of 2025-07-14, Together does not do any input cache token tracking
-            return message_content, _create_usage(
-                model_name,
-                prompt_tokens,
-                completion_tokens,
-                cache_read_input_tokens,
-                cache_creation_input_tokens,
-            )
-
-    if HAS_GOOGLE_GENAI:
-        from judgeval.common.tracer.providers import (
-            google_genai_Client,
-            google_genai_AsyncClient,
-        )
-
-        assert google_genai_Client is not None, "Google GenAI client not found"
-        assert google_genai_AsyncClient is not None, (
-            "Google GenAI async client not found"
-        )
-        if isinstance(client, (google_genai_Client, google_genai_AsyncClient)):
-            model_name = response.model_version
-            prompt_tokens = response.usage_metadata.prompt_token_count
-            completion_tokens = response.usage_metadata.candidates_token_count
-            message_content = response.candidates[0].content.parts[0].text
-
-            if hasattr(response.usage_metadata, "cached_content_token_count"):
-                cache_read_input_tokens = (
-                    response.usage_metadata.cached_content_token_count
-                )
-            return message_content, _create_usage(
-                model_name,
-                prompt_tokens,
-                completion_tokens,
-                cache_read_input_tokens,
-                cache_creation_input_tokens,
-            )
-
-    if HAS_ANTHROPIC:
-        from judgeval.common.tracer.providers import (
-            anthropic_Anthropic,
-            anthropic_AsyncAnthropic,
-        )
-
-        assert anthropic_Anthropic is not None, "Anthropic client not found"
-        assert anthropic_AsyncAnthropic is not None, "Anthropic async client not found"
-        if isinstance(client, (anthropic_Anthropic, anthropic_AsyncAnthropic)):
-            model_name = response.model
-            prompt_tokens = response.usage.input_tokens
-            completion_tokens = response.usage.output_tokens
-            cache_read_input_tokens = response.usage.cache_read_input_tokens
-            cache_creation_input_tokens = response.usage.cache_creation_input_tokens
-            message_content = response.content[0].text
-            return message_content, _create_usage(
-                model_name,
-                prompt_tokens,
-                completion_tokens,
-                cache_read_input_tokens,
-                cache_creation_input_tokens,
-            )
-
-    if HAS_GROQ:
-        from judgeval.common.tracer.providers import groq_Groq, groq_AsyncGroq
-
-        assert groq_Groq is not None, "Groq client not found"
-        assert groq_AsyncGroq is not None, "Groq async client not found"
-        if isinstance(client, (groq_Groq, groq_AsyncGroq)):
-            model_name = "groq/" + response.model
-            prompt_tokens = response.usage.prompt_tokens
-            completion_tokens = response.usage.completion_tokens
-            message_content = response.choices[0].message.content
-            return message_content, _create_usage(
-                model_name,
-                prompt_tokens,
-                completion_tokens,
-                cache_read_input_tokens,
-                cache_creation_input_tokens,
-            )
-
-    # Check for TrainableModel
-    try:
-        from judgeval.common.trainer import TrainableModel
-
-        if isinstance(client, TrainableModel):
-            # TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
-            if (
-                hasattr(response, "model")
-                and hasattr(response, "usage")
-                and hasattr(response, "choices")
-            ):
-                model_name = response.model
-                prompt_tokens = response.usage.prompt_tokens if response.usage else 0
-                completion_tokens = (
-                    response.usage.completion_tokens if response.usage else 0
-                )
-                message_content = response.choices[0].message.content
-
-                # Use LiteLLM cost calculation with fireworks_ai prefix
-                # LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
-                fireworks_model_name = f"fireworks_ai/{model_name}"
-                return message_content, _create_usage(
-                    fireworks_model_name,
-                    prompt_tokens,
-                    completion_tokens,
-                    cache_read_input_tokens,
-                    cache_creation_input_tokens,
-                )
-    except ImportError:
-        pass  # TrainableModel not available
-
-    judgeval_logger.warning(f"Unsupported client type: {type(client)}")
-    return None, None
-
-
-def _create_usage(
-    model_name: str,
-    prompt_tokens: int,
-    completion_tokens: int,
-    cache_read_input_tokens: int = 0,
-    cache_creation_input_tokens: int = 0,
-) -> TraceUsage:
-    """Helper function to create TraceUsage object with cost calculation."""
-    prompt_cost, completion_cost = cost_per_token(
-        model=model_name,
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        cache_read_input_tokens=cache_read_input_tokens,
-        cache_creation_input_tokens=cache_creation_input_tokens,
-    )
-    total_cost_usd = (
-        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
-    )
-    return TraceUsage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-        cache_read_input_tokens=cache_read_input_tokens,
-        cache_creation_input_tokens=cache_creation_input_tokens,
-        prompt_tokens_cost_usd=prompt_cost,
-        completion_tokens_cost_usd=completion_cost,
-        total_cost_usd=total_cost_usd,
-        model_name=model_name,
-    )
-
-
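The arithmetic above is straightforward: `total_tokens` sums both sides, and `total_cost_usd` sums the per-side USD costs returned by `cost_per_token`. A worked example with illustrative (not real) rates of $0.005/1K prompt and $0.015/1K completion tokens:

```python
prompt_cost = 1_000 / 1_000 * 0.005      # 0.005 USD for 1,000 prompt tokens
completion_cost = 200 / 1_000 * 0.015    # 0.003 USD for 200 completion tokens
total_tokens = 1_000 + 200               # 1,200 -> TraceUsage.total_tokens
total_cost_usd = prompt_cost + completion_cost  # 0.008 USD
```

Note that the truthiness check `if prompt_cost and completion_cost` leaves `total_cost_usd` as `None` whenever either side is 0 or `None`, so a legitimately free side zeroes out the total.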
-def combine_args_kwargs(func, args, kwargs):
-    """
-    Combine positional arguments and keyword arguments into a single dictionary.
-
-    Args:
-        func: The function being called
-        args: Tuple of positional arguments
-        kwargs: Dictionary of keyword arguments
-
-    Returns:
-        A dictionary combining both args and kwargs
-    """
-    try:
-        import inspect
-
-        sig = inspect.signature(func)
-        param_names = list(sig.parameters.keys())
-
-        args_dict = {}
-        for i, arg in enumerate(args):
-            if i < len(param_names):
-                args_dict[param_names[i]] = arg
-            else:
-                args_dict[f"arg{i}"] = arg
-
-        return {**args_dict, **kwargs}
-    except Exception:
-        # Fallback if signature inspection fails
-        return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
-
-
-def cost_per_token(*args, **kwargs):
-    try:
-        prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = (
-            _original_cost_per_token(*args, **kwargs)
-        )
-        if (
-            prompt_tokens_cost_usd_dollar == 0
-            and completion_tokens_cost_usd_dollar == 0
-        ):
-            judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
-        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
-    except Exception as e:
-        judgeval_logger.warning(f"Error calculating cost per token: {e}")
-        return None, None
-
-
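`combine_args_kwargs` maps positionals onto parameter names and falls back to `argN` keys past the end of the signature. A self-contained equivalent showing the behavior:

```python
import inspect


def combine(func, args, kwargs):
    # Same strategy as the removed combine_args_kwargs above.
    names = list(inspect.signature(func).parameters)
    merged = {
        (names[i] if i < len(names) else f"arg{i}"): value
        for i, value in enumerate(args)
    }
    return {**merged, **kwargs}


def greet(name, punctuation="!"):
    return f"hi {name}{punctuation}"


assert combine(greet, ("Ada",), {"punctuation": "?"}) == {
    "name": "Ada",
    "punctuation": "?",
}
```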
-# --- Helper function for instance-prefixed qual_name ---
-def get_instance_prefixed_name(instance, class_name, class_identifiers):
-    """
-    Returns the agent name (prefix) if the class and attribute are found in class_identifiers.
-    Otherwise, returns None.
-    """
-    if class_name in class_identifiers:
-        class_config = class_identifiers[class_name]
-        attr = class_config.get("identifier")
-        if attr:
-            if hasattr(instance, attr) and not callable(getattr(instance, attr)):
-                instance_name = getattr(instance, attr)
-                return instance_name
-            else:
-                raise Exception(
-                    f"Attribute {attr} does not exist for {class_name}. Check your agent() decorator."
-                )
-    return None