judgeval-0.0.1-py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- judgeval/__init__.py +83 -0
- judgeval/clients.py +19 -0
- judgeval/common/__init__.py +8 -0
- judgeval/common/exceptions.py +28 -0
- judgeval/common/logger.py +189 -0
- judgeval/common/tracer.py +587 -0
- judgeval/common/utils.py +763 -0
- judgeval/constants.py +55 -0
- judgeval/data/__init__.py +14 -0
- judgeval/data/api_example.py +111 -0
- judgeval/data/datasets/__init__.py +4 -0
- judgeval/data/datasets/dataset.py +407 -0
- judgeval/data/datasets/ground_truth.py +54 -0
- judgeval/data/datasets/utils.py +74 -0
- judgeval/data/example.py +76 -0
- judgeval/data/result.py +83 -0
- judgeval/data/scorer_data.py +86 -0
- judgeval/evaluation_run.py +130 -0
- judgeval/judges/__init__.py +7 -0
- judgeval/judges/base_judge.py +44 -0
- judgeval/judges/litellm_judge.py +49 -0
- judgeval/judges/mixture_of_judges.py +248 -0
- judgeval/judges/together_judge.py +55 -0
- judgeval/judges/utils.py +45 -0
- judgeval/judgment_client.py +244 -0
- judgeval/run_evaluation.py +355 -0
- judgeval/scorers/__init__.py +30 -0
- judgeval/scorers/base_scorer.py +51 -0
- judgeval/scorers/custom_scorer.py +134 -0
- judgeval/scorers/judgeval_scorers/__init__.py +21 -0
- judgeval/scorers/judgeval_scorers/answer_relevancy.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_precision.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_recall.py +19 -0
- judgeval/scorers/judgeval_scorers/contextual_relevancy.py +22 -0
- judgeval/scorers/judgeval_scorers/faithfulness.py +19 -0
- judgeval/scorers/judgeval_scorers/hallucination.py +19 -0
- judgeval/scorers/judgeval_scorers/json_correctness.py +32 -0
- judgeval/scorers/judgeval_scorers/summarization.py +20 -0
- judgeval/scorers/judgeval_scorers/tool_correctness.py +19 -0
- judgeval/scorers/prompt_scorer.py +439 -0
- judgeval/scorers/score.py +427 -0
- judgeval/scorers/utils.py +175 -0
- judgeval-0.0.1.dist-info/METADATA +40 -0
- judgeval-0.0.1.dist-info/RECORD +46 -0
- judgeval-0.0.1.dist-info/WHEEL +4 -0
- judgeval-0.0.1.dist-info/licenses/LICENSE.md +202 -0
@@ -0,0 +1,587 @@
+"""
+Tracing system for judgeval that allows for function tracing using decorators.
+"""
+
+import time
+import functools
+import requests
+import uuid
+from contextlib import contextmanager
+from typing import (
+    Optional,
+    Any,
+    List,
+    Literal,
+    Tuple,
+    Generator,
+    TypeAlias,
+    Union
+)
+from dataclasses import dataclass, field
+from datetime import datetime
+from openai import OpenAI
+from together import Together
+from anthropic import Anthropic
+from typing import Dict
+import inspect
+import asyncio
+import json
+import warnings
+from pydantic import BaseModel
+from http import HTTPStatus
+
+from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL
+from judgeval.judgment_client import JudgmentClient
+from judgeval.data import Example
+from judgeval.scorers import JudgmentScorer, CustomScorer
+from judgeval.data.result import ScoringResult
+
+# Define type aliases for better code readability and maintainability
+ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic]  # Supported API clients
+TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation']  # Valid trace entry types
+SpanType = Literal['span', 'tool', 'llm', 'evaluation']
+@dataclass
+class TraceEntry:
+    """Represents a single trace entry with its visual representation.
+
+    Visual representations:
+    - enter: → (function entry)
+    - exit: ← (function exit)
+    - output: Output: (function return value)
+    - input: Input: (function parameters)
+    - evaluation: Evaluation: (evaluation results)
+    """
+    type: TraceEntryType
+    function: str  # Name of the function being traced
+    depth: int  # Indentation level for nested calls
+    message: str  # Human-readable description
+    timestamp: float  # Unix timestamp when entry was created
+    duration: Optional[float] = None  # Time taken (for exit/evaluation entries)
+    output: Any = None  # Function output value
+    # Use field() for mutable defaults to avoid shared state issues
+    inputs: dict = field(default_factory=dict)
+    span_type: SpanType = "span"
+    evaluation_result: Optional[List[ScoringResult]] = field(default=None)
+
+    def print_entry(self):
+        indent = " " * self.depth
+        if self.type == "enter":
+            print(f"{indent}→ {self.function} (trace: {self.message})")
+        elif self.type == "exit":
+            print(f"{indent}← {self.function} ({self.duration:.3f}s)")
+        elif self.type == "output":
+            print(f"{indent}Output: {self.output}")
+        elif self.type == "input":
+            print(f"{indent}Input: {self.inputs}")
+        elif self.type == "evaluation":
+            print(f"{indent}Evaluation: {self.evaluation_result} ({self.duration:.3f}s)")
+
+    def to_dict(self) -> dict:
+        """Convert the trace entry to a dictionary format for storage/transmission."""
+        try:
+            output = self._serialize_output()
+        except (TypeError, OverflowError, ValueError):
+            # Handle cases where output cannot be serialized
+            warnings.warn(f"Output for function {self.function} is not JSON serializable. Setting to None.")
+            output = None
+
+        # Build a complete dictionary representation of the trace entry
+        return {
+            "type": self.type,
+            "function": self.function,
+            "depth": self.depth,
+            "message": self.message,
+            "timestamp": self.timestamp,
+            "duration": self.duration,
+            "output": output,
+            "inputs": self.inputs or None,  # Convert empty dict to None
+            "evaluation_result": [result.to_dict() for result in self.evaluation_result] if self.evaluation_result else None,
+            "span_type": self.span_type
+        }
+
+    def _serialize_output(self) -> Any:
+        """Helper method to serialize output data safely.
+
+        Handles special cases:
+        - Pydantic models are converted using model_dump()
+        - Other objects must be JSON serializable
+        """
+        if isinstance(self.output, BaseModel):
+            return self.output.model_dump()
+
+        # Verify JSON serialization is possible
+        json.dumps(self.output)
+        return self.output
+
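To make the entry lifecycle concrete, here is a minimal illustrative sketch (an editor's example, not part of the wheel), assuming the module is importable as judgeval.common.tracer:

# Example (illustrative, not in the package): building and rendering an entry by hand.
import time
from judgeval.common.tracer import TraceEntry

entry = TraceEntry(type="enter", function="my_function", depth=1,
                   message="my_function", timestamp=time.time())
entry.print_entry()     # "→ my_function (trace: my_function)", indented by depth
print(entry.to_dict())  # JSON-safe dict; unserializable outputs become None (with a warning)
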
+class TraceClient:
+    """Client for managing a single trace context"""
+    def __init__(self, tracer, trace_id: str, name: str, project_name: str = "default_project"):
+        self.tracer = tracer
+        self.trace_id = trace_id
+        self.name = name
+        self.project_name = project_name
+        self.client: JudgmentClient = tracer.client
+        self.entries: List[TraceEntry] = []
+        self.start_time = time.time()
+        self.span_type: Optional[SpanType] = None  # Set by the observe decorator before inputs/outputs are recorded
+        self._current_span: Optional[str] = None  # Name of the span currently being recorded
+
+    @contextmanager
+    def span(self, name: str, span_type: SpanType = "span"):
+        """Context manager for creating a trace span"""
+        start_time = time.time()
+
+        # Record span entry
+        self.add_entry(TraceEntry(
+            type="enter",
+            function=name,
+            depth=self.tracer.depth,
+            message=name,
+            timestamp=start_time,
+            span_type=span_type
+        ))
+
+        self.tracer.depth += 1
+        prev_span = self._current_span
+        self._current_span = name
+
+        try:
+            yield self
+        finally:
+            self.tracer.depth -= 1
+            duration = time.time() - start_time
+
+            # Record span exit
+            self.add_entry(TraceEntry(
+                type="exit",
+                function=name,
+                depth=self.tracer.depth,
+                message=f"← {name}",
+                timestamp=time.time(),
+                duration=duration,
+                span_type=span_type
+            ))
+            self._current_span = prev_span
+
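The span context manager maintains the indentation depth on the parent tracer, so nested `with` blocks produce a call tree. A sketch, assuming `trace` is the TraceClient yielded by Tracer.trace further down:

# Sketch: nesting spans on an active TraceClient `trace`.
with trace.span("retrieve", span_type="tool"):
    with trace.span("rerank"):
        pass  # each level bumps tracer.depth, so print() indents accordingly
trace.print()
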
+    async def async_evaluate(
+        self,
+        scorers: List[Union[JudgmentScorer, CustomScorer]],
+        input: Optional[str] = None,
+        actual_output: Optional[str] = None,
+        expected_output: Optional[str] = None,
+        context: Optional[List[str]] = None,
+        retrieval_context: Optional[List[str]] = None,
+        tools_called: Optional[List[str]] = None,
+        expected_tools: Optional[List[str]] = None,
+        additional_metadata: Optional[Dict[str, Any]] = None,
+        model: Optional[str] = None,
+        log_results: Optional[bool] = False,
+    ):
+        start_time = time.time()  # Record start time
+        example = Example(
+            input=input,
+            actual_output=actual_output,
+            expected_output=expected_output,
+            context=context,
+            retrieval_context=retrieval_context,
+            tools_called=tools_called,
+            expected_tools=expected_tools,
+            additional_metadata=additional_metadata,
+            trace_id=self.trace_id
+        )
+        scoring_results = self.client.run_evaluation(
+            examples=[example],
+            scorers=scorers,
+            model=model,
+            metadata={},
+            log_results=log_results,
+            project_name="TestSpanLevel1",  # TODO this should be dynamic
+            eval_run_name="TestSpanLevel1",
+            override=True,
+        )
+
+        self.record_evaluation(scoring_results, start_time)  # Pass start_time to record_evaluation
+
+    def record_evaluation(self, results: List[ScoringResult], start_time: float):
+        """Record evaluation results for the current span"""
+        if self._current_span:
+            duration = time.time() - start_time  # Calculate duration from start_time
+
+            self.add_entry(TraceEntry(
+                type="evaluation",
+                function=self._current_span,
+                depth=self.tracer.depth,
+                message=f"Evaluation results for {self._current_span}",
+                timestamp=time.time(),
+                evaluation_result=results,
+                duration=duration,
+                span_type="evaluation"
+            ))
+
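A hedged usage sketch for async_evaluate; the scorer object is a placeholder, since scorer construction lives in judgeval/scorers and is not shown in this file:

# Sketch: scoring the current span's input/output pair (run inside an async function).
# `scorer` is hypothetical; see judgeval/scorers for real JudgmentScorer / CustomScorer construction.
async def qa_step():
    with trace.span("qa"):
        await trace.async_evaluate(
            scorers=[scorer],
            input="What is the capital of France?",
            actual_output="Paris",
            model="gpt-4o",  # placeholder model name
            log_results=True,
        )
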
+    def record_input(self, inputs: dict):
+        """Record input parameters for the current span"""
+        if self._current_span:
+            self.add_entry(TraceEntry(
+                type="input",
+                function=self._current_span,
+                depth=self.tracer.depth,
+                message=f"Inputs to {self._current_span}",
+                timestamp=time.time(),
+                inputs=inputs,
+                span_type=self.span_type
+            ))
+
+    async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
+        """Helper method to update the output of a trace entry once the coroutine completes"""
+        try:
+            result = await coroutine
+            entry.output = result
+            return result
+        except Exception as e:
+            entry.output = f"Error: {str(e)}"
+            raise
+
+    def record_output(self, output: Any):
+        """Record output for the current span"""
+        if self._current_span:
+            entry = TraceEntry(
+                type="output",
+                function=self._current_span,
+                depth=self.tracer.depth,
+                message=f"Output from {self._current_span}",
+                timestamp=time.time(),
+                output="<pending>" if inspect.iscoroutine(output) else output,
+                span_type=self.span_type
+            )
+            self.add_entry(entry)
+
+            if inspect.iscoroutine(output):
+                # Create a task to update the output once the coroutine completes
+                asyncio.create_task(self._update_coroutine_output(entry, output))
+
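Coroutine outputs are recorded as a "<pending>" placeholder and patched in by the background task. A small sketch, assuming a running event loop and an active span on `trace`:

# Sketch: the "<pending>" output is replaced once the coroutine finishes.
async def demo():
    async def fetch():
        return "done"

    trace.record_output(fetch())     # entry.output == "<pending>"
    await asyncio.sleep(0)           # yield so the created task can run (real work may need longer)
    print(trace.entries[-1].output)  # "done"
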
+    def add_entry(self, entry: TraceEntry):
+        """Add a trace entry to this trace context"""
+        self.entries.append(entry)
+        return self
+
+    def print(self):
+        """Print the complete trace with proper visual structure"""
+        for entry in self.entries:
+            entry.print_entry()
+
+    def get_duration(self) -> float:
+        """
+        Get the total duration of this trace
+        """
+        return time.time() - self.start_time
+
+    def condense_trace(self, entries: List[dict]) -> List[dict]:
+        """
+        Condenses trace entries into a single entry for each function call.
+        """
+        condensed = []
+        active_functions = []  # Stack to track nested function calls
+        function_entries = {}  # Store entries for each function
+
+        for entry in entries:
+            function = entry["function"]
+
+            if entry["type"] == "enter":
+                # Initialize new function entry
+                function_entries[function] = {
+                    "depth": entry["depth"],
+                    "function": function,
+                    "timestamp": entry["timestamp"],
+                    "inputs": None,
+                    "output": None,
+                    "evaluation_result": None,
+                    "span_type": entry.get("span_type", "span")
+                }
+                active_functions.append(function)
+
+            elif entry["type"] == "exit" and function in active_functions:
+                # Complete function entry
+                current_entry = function_entries[function]
+                current_entry["duration"] = entry["timestamp"] - current_entry["timestamp"]
+                condensed.append(current_entry)
+                active_functions.remove(function)
+                del function_entries[function]
+
+            elif function in active_functions:
+                # Update existing function entry with additional data
+                current_entry = function_entries[function]
+
+                if entry["type"] == "input" and entry["inputs"]:
+                    current_entry["inputs"] = entry["inputs"]
+
+                if entry["type"] == "output" and entry["output"]:
+                    current_entry["output"] = entry["output"]
+
+                if entry["type"] == "evaluation" and entry["evaluation_result"]:
+                    current_entry["evaluation_result"] = entry["evaluation_result"]
+
+        # Sort by timestamp
+        condensed.sort(key=lambda x: x["timestamp"])
+        return condensed
+
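To see what condensing does, here is a hand-built sketch of the row shapes it consumes and produces; the keys mirror TraceEntry.to_dict and the timestamps are illustrative:

# Sketch: enter/input/output/exit rows for one call collapse into a single row.
raw = [
    {"type": "enter", "function": "f", "depth": 0, "message": "f", "timestamp": 1.0,
     "duration": None, "inputs": None, "output": None, "evaluation_result": None, "span_type": "span"},
    {"type": "input", "function": "f", "depth": 1, "message": "Inputs to f", "timestamp": 1.1,
     "duration": None, "inputs": {"x": 1}, "output": None, "evaluation_result": None, "span_type": "span"},
    {"type": "output", "function": "f", "depth": 1, "message": "Output from f", "timestamp": 1.2,
     "duration": None, "inputs": None, "output": 2, "evaluation_result": None, "span_type": "span"},
    {"type": "exit", "function": "f", "depth": 0, "message": "← f", "timestamp": 1.5,
     "duration": 0.5, "inputs": None, "output": None, "evaluation_result": None, "span_type": "span"},
]
trace.condense_trace(raw)
# -> [{"depth": 0, "function": "f", "timestamp": 1.0, "inputs": {"x": 1}, "output": 2,
#      "evaluation_result": None, "span_type": "span", "duration": 0.5}]
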
+    def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
+        """
+        Save the current trace to the database.
+        Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
+        """
+        # Calculate total elapsed time
+        total_duration = self.get_duration()
+
+        raw_entries = [entry.to_dict() for entry in self.entries]
+        condensed_entries = self.condense_trace(raw_entries)
+
+        # Create trace document
+        trace_data = {
+            "trace_id": self.trace_id,
+            "api_key": self.tracer.api_key,
+            "name": self.name,
+            "project_name": self.project_name,
+            "created_at": datetime.fromtimestamp(self.start_time).isoformat(),
+            "duration": total_duration,
+            "token_counts": {
+                "prompt_tokens": 0,  # Dummy value
+                "completion_tokens": 0,  # Dummy value
+                "total_tokens": 0,  # Dummy value
+            },  # TODO: Add token counts
+            "entries": condensed_entries,
+            "empty_save": empty_save,
+            "overwrite": overwrite
+        }
+
+        # Save trace data by making POST request to API
+        response = requests.post(
+            JUDGMENT_TRACES_SAVE_API_URL,
+            json=trace_data,
+            headers={
+                "Content-Type": "application/json",
+            }
+        )
+
+        if response.status_code == HTTPStatus.BAD_REQUEST:
+            raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
+        elif response.status_code != HTTPStatus.OK:
+            raise ValueError(f"Failed to save trace data: {response.text}")
+
+        return self.trace_id, trace_data
+
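Because a BAD_REQUEST response is surfaced as a ValueError naming a trace-name conflict, one plausible caller pattern is to retry with overwrite; a sketch only, since the same exception type also covers other non-OK statuses:

# Sketch: retrying a save after a suspected name conflict.
try:
    trace_id, payload = trace.save()
except ValueError:
    trace_id, payload = trace.save(overwrite=True)  # replaces the previously saved trace of the same name
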
+class Tracer:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(Tracer, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, api_key: str):
+        if not hasattr(self, 'initialized'):
+
+            if not api_key:
+                raise ValueError("Tracer must be configured with a Judgment API key")
+
+            self.api_key: str = api_key
+            self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
+            self.depth: int = 0
+            self._current_trace: Optional[TraceClient] = None
+            self.initialized: bool = True
+
+    @contextmanager
+    def trace(self, name: str, project_name: str = "default_project", overwrite: bool = False) -> Generator[TraceClient, None, None]:
+        """Start a new trace context using a context manager"""
+        trace_id = str(uuid.uuid4())
+        trace = TraceClient(self, trace_id, name, project_name=project_name)
+        prev_trace = self._current_trace
+        self._current_trace = trace
+
+        # Automatically create top-level span
+        with trace.span(name or "unnamed_trace") as span:
+            try:
+                # Save the trace to the database to handle Evaluations' trace_id referential integrity
+                trace.save(empty_save=True, overwrite=overwrite)
+                yield trace
+            finally:
+                self._current_trace = prev_trace
+
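Putting the singleton and the context manager together, a minimal end-to-end sketch; it assumes a valid Judgment API key (the key shown is a placeholder), and note that the context manager appears to persist only an empty placeholder on entry, so persisting the populated trace is left to the caller:

# Sketch: creating and persisting a trace.
from judgeval.common.tracer import Tracer

tracer = Tracer(api_key="YOUR_JUDGMENT_API_KEY")  # placeholder
with tracer.trace("my_workflow", project_name="demo", overwrite=True) as trace:
    with trace.span("step_1", span_type="tool"):
        pass
    trace.save(overwrite=True)
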
+    def get_current_trace(self) -> Optional[TraceClient]:
+        """
+        Get the current trace context
+        """
+        return self._current_trace
+
+    def observe(self, func=None, *, name=None, span_type: SpanType = "span"):
+        """
+        Decorator to trace function execution with detailed entry/exit information.
+
+        Args:
+            func: The function to trace
+            name: Optional custom name for the function
+            span_type: The type of span to use for this observation (default: "span")
+        """
+        if func is None:
+            return lambda f: self.observe(f, name=name, span_type=span_type)
+
+        if asyncio.iscoroutinefunction(func):
+            @functools.wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                if self._current_trace:
+                    span_name = name or func.__name__
+
+                    with self._current_trace.span(span_name, span_type=span_type) as span:
+                        # Set the span type
+                        span.span_type = span_type
+
+                        # Record inputs
+                        span.record_input({
+                            'args': list(args),
+                            'kwargs': kwargs
+                        })
+
+                        # Execute function
+                        result = await func(*args, **kwargs)
+
+                        # Record output
+                        span.record_output(result)
+
+                        return result
+
+                return await func(*args, **kwargs)
+            return async_wrapper
+        else:
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                if self._current_trace:
+                    span_name = name or func.__name__
+
+                    with self._current_trace.span(span_name, span_type=span_type) as span:
+                        # Set the span type
+                        span.span_type = span_type
+
+                        # Record inputs
+                        span.record_input({
+                            'args': list(args),
+                            'kwargs': kwargs
+                        })
+
+                        # Execute function
+                        result = func(*args, **kwargs)
+
+                        # Record output
+                        span.record_output(result)
+
+                        return result
+
+                return func(*args, **kwargs)
+            return wrapper
+
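A sketch of the decorator in use, continuing the `tracer` from the earlier sketch; outside an active trace the wrapped function simply runs untraced:

# Sketch: observed functions record their args/kwargs and return value as a span.
@tracer.observe(span_type="tool")
def lookup(city: str) -> str:
    return f"weather in {city}"

with tracer.trace("weather_run") as trace:
    lookup("Paris")  # recorded as a "tool" span with inputs and output
lookup("Paris")      # no active trace: runs normally, nothing recorded
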
+def wrap(client: Any) -> Any:
+    """
+    Wraps an API client to add tracing capabilities.
+    Supports OpenAI, Together, and Anthropic clients.
+    """
+    tracer = Tracer._instance  # Get the global tracer instance
+
+    # Get the appropriate configuration for this client type
+    span_name, original_create = _get_client_config(client)
+
+    def traced_create(*args, **kwargs):
+        # Skip tracing if no active trace
+        if not (tracer and tracer._current_trace):
+            return original_create(*args, **kwargs)
+
+        with tracer._current_trace.span(span_name, span_type="llm") as span:
+            # Format and record the input parameters
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Make the actual API call
+            response = original_create(*args, **kwargs)
+
+            # Format and record the output
+            output_data = _format_output_data(client, response)
+            span.record_output(output_data)
+
+            return response
+
+    # Replace the original method with our traced version
+    if isinstance(client, (OpenAI, Together)):
+        client.chat.completions.create = traced_create
+    elif isinstance(client, Anthropic):
+        client.messages.create = traced_create
+
+    return client
+
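And a sketch of client wrapping; the Tracer must be constructed first so Tracer._instance is set, and the OpenAI client assumes OPENAI_API_KEY is available in the environment:

# Sketch: wrapped client calls show up as "llm" spans inside an active trace.
from openai import OpenAI

client = wrap(OpenAI())
with tracer.trace("qa_run") as trace:
    client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
    )
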
+# Helper functions for client-specific operations
+
+def _get_client_config(client: ApiClient) -> tuple[str, callable]:
+    """Returns configuration tuple for the given API client.
+
+    Args:
+        client: An instance of OpenAI, Together, or Anthropic client
+
+    Returns:
+        tuple: (span_name, create_method)
+            - span_name: String identifier for tracing
+            - create_method: Reference to the client's creation method
+
+    Raises:
+        ValueError: If client type is not supported
+    """
+    if isinstance(client, OpenAI):
+        return "OPENAI_API_CALL", client.chat.completions.create
+    elif isinstance(client, Together):
+        return "TOGETHER_API_CALL", client.chat.completions.create
+    elif isinstance(client, Anthropic):
+        return "ANTHROPIC_API_CALL", client.messages.create
+    raise ValueError(f"Unsupported client type: {type(client)}")
+
+def _format_input_data(client: ApiClient, **kwargs) -> dict:
+    """Format input parameters based on client type.
+
+    Extracts relevant parameters from kwargs based on the client type
+    to ensure consistent tracing across different APIs.
+    """
+    if isinstance(client, (OpenAI, Together)):
+        return {
+            "model": kwargs.get("model"),
+            "messages": kwargs.get("messages"),
+        }
+    # Anthropic requires additional max_tokens parameter
+    return {
+        "model": kwargs.get("model"),
+        "messages": kwargs.get("messages"),
+        "max_tokens": kwargs.get("max_tokens")
+    }
+
+def _format_output_data(client: ApiClient, response: Any) -> dict:
+    """Format API response data based on client type.
+
+    Normalizes different response formats into a consistent structure
+    for tracing purposes.
+
+    Returns:
+        dict containing:
+            - content: The generated text
+            - usage: Token usage statistics
+    """
+    if isinstance(client, (OpenAI, Together)):
+        return {
+            "content": response.choices[0].message.content,
+            "usage": {
+                "prompt_tokens": response.usage.prompt_tokens,
+                "completion_tokens": response.usage.completion_tokens,
+                "total_tokens": response.usage.total_tokens
+            }
+        }
+    # Anthropic has a different response structure
+    return {
+        "content": response.content[0].text,
+        "usage": {
+            "input_tokens": response.usage.input_tokens,
+            "output_tokens": response.usage.output_tokens,
+            "total_tokens": response.usage.input_tokens + response.usage.output_tokens
+        }
+    }