judgeval 0.0.11__py3-none-any.whl → 0.22.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of judgeval might be problematic. Click here for more details.
- judgeval/__init__.py +177 -12
- judgeval/api/__init__.py +519 -0
- judgeval/api/api_types.py +407 -0
- judgeval/cli.py +79 -0
- judgeval/constants.py +76 -47
- judgeval/data/__init__.py +3 -3
- judgeval/data/evaluation_run.py +125 -0
- judgeval/data/example.py +15 -56
- judgeval/data/judgment_types.py +450 -0
- judgeval/data/result.py +29 -73
- judgeval/data/scorer_data.py +29 -62
- judgeval/data/scripts/fix_default_factory.py +23 -0
- judgeval/data/scripts/openapi_transform.py +123 -0
- judgeval/data/trace.py +121 -0
- judgeval/dataset/__init__.py +264 -0
- judgeval/env.py +52 -0
- judgeval/evaluation/__init__.py +344 -0
- judgeval/exceptions.py +27 -0
- judgeval/integrations/langgraph/__init__.py +13 -0
- judgeval/integrations/openlit/__init__.py +50 -0
- judgeval/judges/__init__.py +2 -3
- judgeval/judges/base_judge.py +2 -3
- judgeval/judges/litellm_judge.py +100 -20
- judgeval/judges/together_judge.py +101 -20
- judgeval/judges/utils.py +20 -24
- judgeval/logger.py +62 -0
- judgeval/prompt/__init__.py +330 -0
- judgeval/scorers/__init__.py +18 -25
- judgeval/scorers/agent_scorer.py +17 -0
- judgeval/scorers/api_scorer.py +45 -41
- judgeval/scorers/base_scorer.py +83 -38
- judgeval/scorers/example_scorer.py +17 -0
- judgeval/scorers/exceptions.py +1 -0
- judgeval/scorers/judgeval_scorers/__init__.py +0 -148
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +19 -17
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +13 -19
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +12 -19
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +13 -19
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +15 -0
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +327 -0
- judgeval/scorers/score.py +77 -306
- judgeval/scorers/utils.py +4 -199
- judgeval/tracer/__init__.py +1122 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +40 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +59 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +63 -0
- judgeval/tracer/llm/__init__.py +7 -0
- judgeval/tracer/llm/config.py +78 -0
- judgeval/tracer/llm/constants.py +9 -0
- judgeval/tracer/llm/llm_anthropic/__init__.py +3 -0
- judgeval/tracer/llm/llm_anthropic/config.py +6 -0
- judgeval/tracer/llm/llm_anthropic/messages.py +452 -0
- judgeval/tracer/llm/llm_anthropic/messages_stream.py +322 -0
- judgeval/tracer/llm/llm_anthropic/wrapper.py +59 -0
- judgeval/tracer/llm/llm_google/__init__.py +3 -0
- judgeval/tracer/llm/llm_google/config.py +6 -0
- judgeval/tracer/llm/llm_google/generate_content.py +127 -0
- judgeval/tracer/llm/llm_google/wrapper.py +30 -0
- judgeval/tracer/llm/llm_openai/__init__.py +3 -0
- judgeval/tracer/llm/llm_openai/beta_chat_completions.py +216 -0
- judgeval/tracer/llm/llm_openai/chat_completions.py +501 -0
- judgeval/tracer/llm/llm_openai/config.py +6 -0
- judgeval/tracer/llm/llm_openai/responses.py +506 -0
- judgeval/tracer/llm/llm_openai/utils.py +42 -0
- judgeval/tracer/llm/llm_openai/wrapper.py +63 -0
- judgeval/tracer/llm/llm_together/__init__.py +3 -0
- judgeval/tracer/llm/llm_together/chat_completions.py +406 -0
- judgeval/tracer/llm/llm_together/config.py +6 -0
- judgeval/tracer/llm/llm_together/wrapper.py +52 -0
- judgeval/tracer/llm/providers.py +19 -0
- judgeval/tracer/managers.py +167 -0
- judgeval/tracer/processors/__init__.py +220 -0
- judgeval/tracer/utils.py +19 -0
- judgeval/trainer/__init__.py +14 -0
- judgeval/trainer/base_trainer.py +122 -0
- judgeval/trainer/config.py +128 -0
- judgeval/trainer/console.py +144 -0
- judgeval/trainer/fireworks_trainer.py +396 -0
- judgeval/trainer/trainable_model.py +243 -0
- judgeval/trainer/trainer.py +70 -0
- judgeval/utils/async_utils.py +39 -0
- judgeval/utils/decorators/__init__.py +0 -0
- judgeval/utils/decorators/dont_throw.py +37 -0
- judgeval/utils/decorators/use_once.py +13 -0
- judgeval/utils/file_utils.py +97 -0
- judgeval/utils/guards.py +36 -0
- judgeval/utils/meta.py +27 -0
- judgeval/utils/project.py +15 -0
- judgeval/utils/serialize.py +253 -0
- judgeval/utils/testing.py +70 -0
- judgeval/utils/url.py +10 -0
- judgeval/utils/version_check.py +28 -0
- judgeval/utils/wrappers/README.md +3 -0
- judgeval/utils/wrappers/__init__.py +15 -0
- judgeval/utils/wrappers/immutable_wrap_async.py +74 -0
- judgeval/utils/wrappers/immutable_wrap_async_iterator.py +84 -0
- judgeval/utils/wrappers/immutable_wrap_sync.py +66 -0
- judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +84 -0
- judgeval/utils/wrappers/mutable_wrap_async.py +67 -0
- judgeval/utils/wrappers/mutable_wrap_sync.py +67 -0
- judgeval/utils/wrappers/py.typed +0 -0
- judgeval/utils/wrappers/utils.py +35 -0
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- judgeval-0.22.2.dist-info/METADATA +265 -0
- judgeval-0.22.2.dist-info/RECORD +112 -0
- judgeval-0.22.2.dist-info/entry_points.txt +2 -0
- judgeval/clients.py +0 -39
- judgeval/common/__init__.py +0 -8
- judgeval/common/exceptions.py +0 -28
- judgeval/common/logger.py +0 -189
- judgeval/common/tracer.py +0 -798
- judgeval/common/utils.py +0 -763
- judgeval/data/api_example.py +0 -111
- judgeval/data/datasets/__init__.py +0 -5
- judgeval/data/datasets/dataset.py +0 -286
- judgeval/data/datasets/eval_dataset_client.py +0 -193
- judgeval/data/datasets/ground_truth.py +0 -54
- judgeval/data/datasets/utils.py +0 -74
- judgeval/evaluation_run.py +0 -132
- judgeval/judges/mixture_of_judges.py +0 -248
- judgeval/judgment_client.py +0 -354
- judgeval/run_evaluation.py +0 -439
- judgeval/scorers/judgeval_scorer.py +0 -140
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +0 -19
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +0 -19
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +0 -22
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -19
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +0 -32
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +0 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_correctness.py +0 -19
- judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -54
- judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -24
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -277
- judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
- judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
- judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -325
- judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -263
- judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
- judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
- judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -550
- judgeval/scorers/judgeval_scorers/local_implementations/tool_correctness/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/local_implementations/tool_correctness/tool_correctness_scorer.py +0 -157
- judgeval/scorers/prompt_scorer.py +0 -439
- judgeval-0.0.11.dist-info/METADATA +0 -36
- judgeval-0.0.11.dist-info/RECORD +0 -84
- {judgeval-0.0.11.dist-info → judgeval-0.22.2.dist-info}/WHEEL +0 -0
- {judgeval-0.0.11.dist-info → judgeval-0.22.2.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
DELETED
|
@@ -1,798 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Tracing system for judgeval that allows for function tracing using decorators.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import os
|
|
6
|
-
import time
|
|
7
|
-
import functools
|
|
8
|
-
import requests
|
|
9
|
-
import uuid
|
|
10
|
-
from contextlib import contextmanager
|
|
11
|
-
from typing import (
|
|
12
|
-
Optional,
|
|
13
|
-
Any,
|
|
14
|
-
List,
|
|
15
|
-
Literal,
|
|
16
|
-
Tuple,
|
|
17
|
-
Generator,
|
|
18
|
-
TypeAlias,
|
|
19
|
-
Union
|
|
20
|
-
)
|
|
21
|
-
from dataclasses import (
|
|
22
|
-
dataclass,
|
|
23
|
-
field
|
|
24
|
-
)
|
|
25
|
-
from datetime import datetime
|
|
26
|
-
from openai import OpenAI
|
|
27
|
-
from together import Together
|
|
28
|
-
from anthropic import Anthropic
|
|
29
|
-
from typing import Dict
|
|
30
|
-
import inspect
|
|
31
|
-
import asyncio
|
|
32
|
-
import json
|
|
33
|
-
import warnings
|
|
34
|
-
from pydantic import BaseModel
|
|
35
|
-
from http import HTTPStatus
|
|
36
|
-
|
|
37
|
-
import pika
|
|
38
|
-
import os
|
|
39
|
-
|
|
40
|
-
from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL, JUDGMENT_TRACES_FETCH_API_URL, RABBITMQ_HOST, RABBITMQ_PORT, RABBITMQ_QUEUE, JUDGMENT_TRACES_DELETE_API_URL
|
|
41
|
-
from judgeval.judgment_client import JudgmentClient
|
|
42
|
-
from judgeval.data import Example
|
|
43
|
-
from judgeval.scorers import APIJudgmentScorer, JudgevalScorer, ScorerWrapper
|
|
44
|
-
|
|
45
|
-
from rich import print as rprint
|
|
46
|
-
|
|
47
|
-
from judgeval.data.result import ScoringResult
|
|
48
|
-
from judgeval.evaluation_run import EvaluationRun
|
|
49
|
-
|
|
50
|
-
# Define type aliases for better code readability and maintainability
|
|
51
|
-
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
|
|
52
|
-
TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
|
|
53
|
-
SpanType = Literal['span', 'tool', 'llm', 'evaluation']
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@dataclass
|
|
57
|
-
class TraceEntry:
|
|
58
|
-
"""Represents a single trace entry with its visual representation.
|
|
59
|
-
|
|
60
|
-
Visual representations:
|
|
61
|
-
- enter: → (function entry)
|
|
62
|
-
- exit: ← (function exit)
|
|
63
|
-
- output: Output: (function return value)
|
|
64
|
-
- input: Input: (function parameters)
|
|
65
|
-
- evaluation: Evaluation: (evaluation results)
|
|
66
|
-
"""
|
|
67
|
-
type: TraceEntryType
|
|
68
|
-
function: str # Name of the function being traced
|
|
69
|
-
depth: int # Indentation level for nested calls
|
|
70
|
-
message: str # Human-readable description
|
|
71
|
-
timestamp: float # Unix timestamp when entry was created
|
|
72
|
-
duration: Optional[float] = None # Time taken (for exit/evaluation entries)
|
|
73
|
-
output: Any = None # Function output value
|
|
74
|
-
# Use field() for mutable defaults to avoid shared state issues
|
|
75
|
-
inputs: dict = field(default_factory=dict)
|
|
76
|
-
span_type: SpanType = "span"
|
|
77
|
-
evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
|
|
78
|
-
|
|
79
|
-
def print_entry(self):
|
|
80
|
-
indent = " " * self.depth
|
|
81
|
-
if self.type == "enter":
|
|
82
|
-
print(f"{indent}→ {self.function} (trace: {self.message})")
|
|
83
|
-
elif self.type == "exit":
|
|
84
|
-
print(f"{indent}← {self.function} ({self.duration:.3f}s)")
|
|
85
|
-
elif self.type == "output":
|
|
86
|
-
print(f"{indent}Output: {self.output}")
|
|
87
|
-
elif self.type == "input":
|
|
88
|
-
print(f"{indent}Input: {self.inputs}")
|
|
89
|
-
elif self.type == "evaluation":
|
|
90
|
-
for evaluation_run in self.evaluation_runs:
|
|
91
|
-
print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
|
|
92
|
-
|
|
93
|
-
def _serialize_inputs(self) -> dict:
|
|
94
|
-
"""Helper method to serialize input data safely.
|
|
95
|
-
|
|
96
|
-
Returns a dict with serializable versions of inputs, converting non-serializable
|
|
97
|
-
objects to None with a warning.
|
|
98
|
-
"""
|
|
99
|
-
serialized_inputs = {}
|
|
100
|
-
for key, value in self.inputs.items():
|
|
101
|
-
if isinstance(value, BaseModel):
|
|
102
|
-
serialized_inputs[key] = value.model_dump()
|
|
103
|
-
elif isinstance(value, (list, tuple)):
|
|
104
|
-
# Handle lists/tuples of arguments
|
|
105
|
-
serialized_inputs[key] = [
|
|
106
|
-
item.model_dump() if isinstance(item, BaseModel)
|
|
107
|
-
else None if not self._is_json_serializable(item)
|
|
108
|
-
else item
|
|
109
|
-
for item in value
|
|
110
|
-
]
|
|
111
|
-
else:
|
|
112
|
-
if self._is_json_serializable(value):
|
|
113
|
-
serialized_inputs[key] = value
|
|
114
|
-
else:
|
|
115
|
-
warnings.warn(f"Input '{key}' for function {self.function} is not JSON serializable. Setting to None.")
|
|
116
|
-
serialized_inputs[key] = None
|
|
117
|
-
return serialized_inputs
|
|
118
|
-
|
|
119
|
-
def _is_json_serializable(self, obj: Any) -> bool:
|
|
120
|
-
"""Helper method to check if an object is JSON serializable."""
|
|
121
|
-
try:
|
|
122
|
-
json.dumps(obj)
|
|
123
|
-
return True
|
|
124
|
-
except (TypeError, OverflowError, ValueError):
|
|
125
|
-
return False
|
|
126
|
-
|
|
127
|
-
def to_dict(self) -> dict:
|
|
128
|
-
"""Convert the trace entry to a dictionary format for storage/transmission."""
|
|
129
|
-
return {
|
|
130
|
-
"type": self.type,
|
|
131
|
-
"function": self.function,
|
|
132
|
-
"depth": self.depth,
|
|
133
|
-
"message": self.message,
|
|
134
|
-
"timestamp": self.timestamp,
|
|
135
|
-
"duration": self.duration,
|
|
136
|
-
"output": self._serialize_output(),
|
|
137
|
-
"inputs": self._serialize_inputs(),
|
|
138
|
-
"evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
|
|
139
|
-
"span_type": self.span_type
|
|
140
|
-
}
|
|
141
|
-
|
|
142
|
-
def _serialize_output(self) -> Any:
|
|
143
|
-
"""Helper method to serialize output data safely.
|
|
144
|
-
|
|
145
|
-
Handles special cases:
|
|
146
|
-
- Pydantic models are converted using model_dump()
|
|
147
|
-
- We try to serialize into JSON, then string, then the base representation (__repr__)
|
|
148
|
-
- Non-serializable objects return None with a warning
|
|
149
|
-
"""
|
|
150
|
-
|
|
151
|
-
def safe_stringify(output, function_name):
|
|
152
|
-
"""
|
|
153
|
-
Safely converts an object to a string or repr, handling serialization issues gracefully.
|
|
154
|
-
"""
|
|
155
|
-
try:
|
|
156
|
-
return str(output)
|
|
157
|
-
except (TypeError, OverflowError, ValueError):
|
|
158
|
-
pass
|
|
159
|
-
|
|
160
|
-
try:
|
|
161
|
-
return repr(output)
|
|
162
|
-
except (TypeError, OverflowError, ValueError):
|
|
163
|
-
pass
|
|
164
|
-
|
|
165
|
-
warnings.warn(
|
|
166
|
-
f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
|
|
167
|
-
)
|
|
168
|
-
return None
|
|
169
|
-
|
|
170
|
-
if isinstance(self.output, BaseModel):
|
|
171
|
-
return self.output.model_dump()
|
|
172
|
-
|
|
173
|
-
try:
|
|
174
|
-
# Try to serialize the output to verify it's JSON compatible
|
|
175
|
-
json.dumps(self.output)
|
|
176
|
-
return self.output
|
|
177
|
-
except (TypeError, OverflowError, ValueError):
|
|
178
|
-
return safe_stringify(self.output, self.function)
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
class TraceManagerClient:
|
|
182
|
-
"""
|
|
183
|
-
Client for handling trace endpoints with the Judgment API
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
Operations include:
|
|
187
|
-
- Fetching a trace by id
|
|
188
|
-
- Saving a trace
|
|
189
|
-
- Deleting a trace
|
|
190
|
-
"""
|
|
191
|
-
def __init__(self, judgment_api_key: str):
|
|
192
|
-
self.judgment_api_key = judgment_api_key
|
|
193
|
-
|
|
194
|
-
def fetch_trace(self, trace_id: str):
|
|
195
|
-
"""
|
|
196
|
-
Fetch a trace by its id
|
|
197
|
-
"""
|
|
198
|
-
response = requests.post(
|
|
199
|
-
JUDGMENT_TRACES_FETCH_API_URL,
|
|
200
|
-
json={
|
|
201
|
-
"trace_id": trace_id,
|
|
202
|
-
"judgment_api_key": self.judgment_api_key,
|
|
203
|
-
},
|
|
204
|
-
headers={
|
|
205
|
-
"Content-Type": "application/json",
|
|
206
|
-
}
|
|
207
|
-
)
|
|
208
|
-
|
|
209
|
-
if response.status_code != HTTPStatus.OK:
|
|
210
|
-
raise ValueError(f"Failed to fetch traces: {response.text}")
|
|
211
|
-
|
|
212
|
-
return response.json()
|
|
213
|
-
|
|
214
|
-
def save_trace(self, trace_data: dict, empty_save: bool):
|
|
215
|
-
"""
|
|
216
|
-
Saves a trace to the database
|
|
217
|
-
|
|
218
|
-
Args:
|
|
219
|
-
trace_data: The trace data to save
|
|
220
|
-
empty_save: Whether to save an empty trace
|
|
221
|
-
NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
|
|
222
|
-
"""
|
|
223
|
-
response = requests.post(
|
|
224
|
-
JUDGMENT_TRACES_SAVE_API_URL,
|
|
225
|
-
json=trace_data,
|
|
226
|
-
headers={
|
|
227
|
-
"Content-Type": "application/json",
|
|
228
|
-
}
|
|
229
|
-
)
|
|
230
|
-
|
|
231
|
-
if response.status_code == HTTPStatus.BAD_REQUEST:
|
|
232
|
-
raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
|
|
233
|
-
elif response.status_code != HTTPStatus.OK:
|
|
234
|
-
raise ValueError(f"Failed to save trace data: {response.text}")
|
|
235
|
-
|
|
236
|
-
if not empty_save and "ui_results_url" in response.json():
|
|
237
|
-
rprint(f"\n🔍 You can view your trace data here: [rgb(106,0,255)]{response.json()['ui_results_url']}[/]\n")
|
|
238
|
-
|
|
239
|
-
def delete_trace(self, trace_id: str):
|
|
240
|
-
"""
|
|
241
|
-
Delete a trace from the database.
|
|
242
|
-
"""
|
|
243
|
-
response = requests.delete(
|
|
244
|
-
JUDGMENT_TRACES_DELETE_API_URL,
|
|
245
|
-
json={
|
|
246
|
-
"judgment_api_key": self.judgment_api_key,
|
|
247
|
-
"trace_ids": [trace_id],
|
|
248
|
-
},
|
|
249
|
-
headers={
|
|
250
|
-
"Content-Type": "application/json",
|
|
251
|
-
}
|
|
252
|
-
)
|
|
253
|
-
|
|
254
|
-
if response.status_code != HTTPStatus.OK:
|
|
255
|
-
raise ValueError(f"Failed to delete trace: {response.text}")
|
|
256
|
-
|
|
257
|
-
return response.json()
|
|
258
|
-
|
|
259
|
-
def delete_traces(self, trace_ids: List[str]):
|
|
260
|
-
"""
|
|
261
|
-
Delete a batch of traces from the database.
|
|
262
|
-
"""
|
|
263
|
-
response = requests.delete(
|
|
264
|
-
JUDGMENT_TRACES_DELETE_API_URL,
|
|
265
|
-
json={
|
|
266
|
-
"judgment_api_key": self.judgment_api_key,
|
|
267
|
-
"trace_ids": trace_ids,
|
|
268
|
-
},
|
|
269
|
-
headers={
|
|
270
|
-
"Content-Type": "application/json",
|
|
271
|
-
}
|
|
272
|
-
)
|
|
273
|
-
|
|
274
|
-
if response.status_code != HTTPStatus.OK:
|
|
275
|
-
raise ValueError(f"Failed to delete trace: {response.text}")
|
|
276
|
-
|
|
277
|
-
return response.json()
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
class TraceClient:
|
|
281
|
-
"""Client for managing a single trace context"""
|
|
282
|
-
def __init__(self, tracer, trace_id: str, name: str, project_name: str = "default_project", overwrite: bool = False):
|
|
283
|
-
self.tracer = tracer
|
|
284
|
-
self.trace_id = trace_id
|
|
285
|
-
self.name = name
|
|
286
|
-
self.project_name = project_name
|
|
287
|
-
self.client: JudgmentClient = tracer.client
|
|
288
|
-
self.entries: List[TraceEntry] = []
|
|
289
|
-
self.start_time = time.time()
|
|
290
|
-
self.span_type = None
|
|
291
|
-
self._current_span: Optional[TraceEntry] = None
|
|
292
|
-
self.overwrite = overwrite
|
|
293
|
-
self.trace_manager_client = TraceManagerClient(tracer.api_key) # Manages DB operations for trace data
|
|
294
|
-
|
|
295
|
-
@contextmanager
|
|
296
|
-
def span(self, name: str, span_type: SpanType = "span"):
|
|
297
|
-
"""Context manager for creating a trace span"""
|
|
298
|
-
start_time = time.time()
|
|
299
|
-
|
|
300
|
-
# Record span entry
|
|
301
|
-
self.add_entry(TraceEntry(
|
|
302
|
-
type="enter",
|
|
303
|
-
function=name,
|
|
304
|
-
depth=self.tracer.depth,
|
|
305
|
-
message=name,
|
|
306
|
-
timestamp=start_time,
|
|
307
|
-
span_type=span_type
|
|
308
|
-
))
|
|
309
|
-
|
|
310
|
-
# Increment nested depth and set current span
|
|
311
|
-
self.tracer.depth += 1
|
|
312
|
-
prev_span = self._current_span
|
|
313
|
-
self._current_span = name
|
|
314
|
-
|
|
315
|
-
try:
|
|
316
|
-
yield self
|
|
317
|
-
finally:
|
|
318
|
-
self.tracer.depth -= 1
|
|
319
|
-
duration = time.time() - start_time
|
|
320
|
-
|
|
321
|
-
# Record span exit
|
|
322
|
-
self.add_entry(TraceEntry(
|
|
323
|
-
type="exit",
|
|
324
|
-
function=name,
|
|
325
|
-
depth=self.tracer.depth,
|
|
326
|
-
message=f"← {name}",
|
|
327
|
-
timestamp=time.time(),
|
|
328
|
-
duration=duration,
|
|
329
|
-
span_type=span_type
|
|
330
|
-
))
|
|
331
|
-
self._current_span = prev_span
|
|
332
|
-
|
|
333
|
-
def async_evaluate(
|
|
334
|
-
self,
|
|
335
|
-
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
|
336
|
-
input: Optional[str] = None,
|
|
337
|
-
actual_output: Optional[str] = None,
|
|
338
|
-
expected_output: Optional[str] = None,
|
|
339
|
-
context: Optional[List[str]] = None,
|
|
340
|
-
retrieval_context: Optional[List[str]] = None,
|
|
341
|
-
tools_called: Optional[List[str]] = None,
|
|
342
|
-
expected_tools: Optional[List[str]] = None,
|
|
343
|
-
additional_metadata: Optional[Dict[str, Any]] = None,
|
|
344
|
-
model: Optional[str] = None,
|
|
345
|
-
log_results: Optional[bool] = True,
|
|
346
|
-
):
|
|
347
|
-
start_time = time.time() # Record start time
|
|
348
|
-
example = Example(
|
|
349
|
-
input=input,
|
|
350
|
-
actual_output=actual_output,
|
|
351
|
-
expected_output=expected_output,
|
|
352
|
-
context=context,
|
|
353
|
-
retrieval_context=retrieval_context,
|
|
354
|
-
tools_called=tools_called,
|
|
355
|
-
expected_tools=expected_tools,
|
|
356
|
-
additional_metadata=additional_metadata,
|
|
357
|
-
trace_id=self.trace_id
|
|
358
|
-
)
|
|
359
|
-
|
|
360
|
-
try:
|
|
361
|
-
# Load appropriate implementations for all scorers
|
|
362
|
-
loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = [
|
|
363
|
-
scorer.load_implementation(use_judgment=True) if isinstance(scorer, ScorerWrapper) else scorer
|
|
364
|
-
for scorer in scorers
|
|
365
|
-
]
|
|
366
|
-
except Exception as e:
|
|
367
|
-
raise ValueError(f"Failed to load scorers: {str(e)}")
|
|
368
|
-
|
|
369
|
-
eval_run = EvaluationRun(
|
|
370
|
-
log_results=log_results,
|
|
371
|
-
project_name=self.project_name,
|
|
372
|
-
eval_name=f"{self.name.capitalize()}-"
|
|
373
|
-
f"{self._current_span}-"
|
|
374
|
-
f"[{','.join(scorer.load_implementation().score_type.capitalize() for scorer in scorers)}]",
|
|
375
|
-
examples=[example],
|
|
376
|
-
scorers=loaded_scorers,
|
|
377
|
-
model=model,
|
|
378
|
-
metadata={},
|
|
379
|
-
judgment_api_key=self.tracer.api_key,
|
|
380
|
-
override=self.overwrite
|
|
381
|
-
)
|
|
382
|
-
|
|
383
|
-
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
|
384
|
-
|
|
385
|
-
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
|
386
|
-
"""
|
|
387
|
-
Add evaluation run data to the trace
|
|
388
|
-
|
|
389
|
-
Args:
|
|
390
|
-
eval_run (EvaluationRun): The evaluation run to add to the trace
|
|
391
|
-
start_time (float): The start time of the evaluation run
|
|
392
|
-
"""
|
|
393
|
-
if self._current_span:
|
|
394
|
-
duration = time.time() - start_time # Calculate duration from start_time
|
|
395
|
-
|
|
396
|
-
self.add_entry(TraceEntry(
|
|
397
|
-
type="evaluation",
|
|
398
|
-
function=self._current_span,
|
|
399
|
-
depth=self.tracer.depth,
|
|
400
|
-
message=f"Evaluation results for {self._current_span}",
|
|
401
|
-
timestamp=time.time(),
|
|
402
|
-
evaluation_runs=[eval_run],
|
|
403
|
-
duration=duration,
|
|
404
|
-
span_type="evaluation"
|
|
405
|
-
))
|
|
406
|
-
|
|
407
|
-
def record_input(self, inputs: dict):
|
|
408
|
-
"""Record input parameters for the current span"""
|
|
409
|
-
if self._current_span:
|
|
410
|
-
self.add_entry(TraceEntry(
|
|
411
|
-
type="input",
|
|
412
|
-
function=self._current_span,
|
|
413
|
-
depth=self.tracer.depth,
|
|
414
|
-
message=f"Inputs to {self._current_span}",
|
|
415
|
-
timestamp=time.time(),
|
|
416
|
-
inputs=inputs,
|
|
417
|
-
span_type=self.span_type
|
|
418
|
-
))
|
|
419
|
-
|
|
420
|
-
async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
|
|
421
|
-
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
|
422
|
-
try:
|
|
423
|
-
result = await coroutine
|
|
424
|
-
entry.output = result
|
|
425
|
-
return result
|
|
426
|
-
except Exception as e:
|
|
427
|
-
entry.output = f"Error: {str(e)}"
|
|
428
|
-
raise
|
|
429
|
-
|
|
430
|
-
def record_output(self, output: Any):
|
|
431
|
-
"""Record output for the current span"""
|
|
432
|
-
if self._current_span:
|
|
433
|
-
entry = TraceEntry(
|
|
434
|
-
type="output",
|
|
435
|
-
function=self._current_span,
|
|
436
|
-
depth=self.tracer.depth,
|
|
437
|
-
message=f"Output from {self._current_span}",
|
|
438
|
-
timestamp=time.time(),
|
|
439
|
-
output="<pending>" if inspect.iscoroutine(output) else output,
|
|
440
|
-
span_type=self.span_type
|
|
441
|
-
)
|
|
442
|
-
self.add_entry(entry)
|
|
443
|
-
|
|
444
|
-
if inspect.iscoroutine(output):
|
|
445
|
-
# Create a task to update the output once the coroutine completes
|
|
446
|
-
asyncio.create_task(self._update_coroutine_output(entry, output))
|
|
447
|
-
|
|
448
|
-
def add_entry(self, entry: TraceEntry):
|
|
449
|
-
"""Add a trace entry to this trace context"""
|
|
450
|
-
self.entries.append(entry)
|
|
451
|
-
return self
|
|
452
|
-
|
|
453
|
-
def print(self):
|
|
454
|
-
"""Print the complete trace with proper visual structure"""
|
|
455
|
-
for entry in self.entries:
|
|
456
|
-
entry.print_entry()
|
|
457
|
-
|
|
458
|
-
def get_duration(self) -> float:
|
|
459
|
-
"""
|
|
460
|
-
Get the total duration of this trace
|
|
461
|
-
"""
|
|
462
|
-
return time.time() - self.start_time
|
|
463
|
-
|
|
464
|
-
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
|
465
|
-
"""
|
|
466
|
-
Condenses trace entries into a single entry for each function call.
|
|
467
|
-
"""
|
|
468
|
-
condensed = []
|
|
469
|
-
active_functions = [] # Stack to track nested function calls
|
|
470
|
-
function_entries = {} # Store entries for each function
|
|
471
|
-
|
|
472
|
-
for entry in entries:
|
|
473
|
-
function = entry["function"]
|
|
474
|
-
|
|
475
|
-
if entry["type"] == "enter":
|
|
476
|
-
# Initialize new function entry
|
|
477
|
-
function_entries[function] = {
|
|
478
|
-
"depth": entry["depth"],
|
|
479
|
-
"function": function,
|
|
480
|
-
"timestamp": entry["timestamp"],
|
|
481
|
-
"inputs": None,
|
|
482
|
-
"output": None,
|
|
483
|
-
"evaluation_runs": [],
|
|
484
|
-
"span_type": entry.get("span_type", "span")
|
|
485
|
-
}
|
|
486
|
-
active_functions.append(function)
|
|
487
|
-
|
|
488
|
-
elif entry["type"] == "exit" and function in active_functions:
|
|
489
|
-
# Complete function entry
|
|
490
|
-
current_entry = function_entries[function]
|
|
491
|
-
current_entry["duration"] = entry["timestamp"] - current_entry["timestamp"]
|
|
492
|
-
condensed.append(current_entry)
|
|
493
|
-
active_functions.remove(function)
|
|
494
|
-
del function_entries[function]
|
|
495
|
-
|
|
496
|
-
elif function in active_functions:
|
|
497
|
-
# Update existing function entry with additional data
|
|
498
|
-
current_entry = function_entries[function]
|
|
499
|
-
|
|
500
|
-
if entry["type"] == "input" and entry["inputs"]:
|
|
501
|
-
current_entry["inputs"] = entry["inputs"]
|
|
502
|
-
|
|
503
|
-
if entry["type"] == "output" and entry["output"]:
|
|
504
|
-
current_entry["output"] = entry["output"]
|
|
505
|
-
|
|
506
|
-
if entry["type"] == "evaluation" and entry["evaluation_runs"]:
|
|
507
|
-
current_entry["evaluation_runs"] = entry["evaluation_runs"]
|
|
508
|
-
|
|
509
|
-
# Sort by timestamp
|
|
510
|
-
condensed.sort(key=lambda x: x["timestamp"])
|
|
511
|
-
return condensed
|
|
512
|
-
|
|
513
|
-
def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
|
|
514
|
-
"""
|
|
515
|
-
Save the current trace to the database.
|
|
516
|
-
Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
|
|
517
|
-
"""
|
|
518
|
-
# Calculate total elapsed time
|
|
519
|
-
total_duration = self.get_duration()
|
|
520
|
-
|
|
521
|
-
raw_entries = [entry.to_dict() for entry in self.entries]
|
|
522
|
-
condensed_entries = self.condense_trace(raw_entries)
|
|
523
|
-
|
|
524
|
-
# Calculate total token counts from LLM API calls
|
|
525
|
-
total_prompt_tokens = 0
|
|
526
|
-
total_completion_tokens = 0
|
|
527
|
-
total_tokens = 0
|
|
528
|
-
|
|
529
|
-
for entry in condensed_entries:
|
|
530
|
-
if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
|
|
531
|
-
usage = entry["output"].get("usage", {})
|
|
532
|
-
# Handle OpenAI/Together format
|
|
533
|
-
if "prompt_tokens" in usage:
|
|
534
|
-
total_prompt_tokens += usage.get("prompt_tokens", 0)
|
|
535
|
-
total_completion_tokens += usage.get("completion_tokens", 0)
|
|
536
|
-
# Handle Anthropic format
|
|
537
|
-
elif "input_tokens" in usage:
|
|
538
|
-
total_prompt_tokens += usage.get("input_tokens", 0)
|
|
539
|
-
total_completion_tokens += usage.get("output_tokens", 0)
|
|
540
|
-
total_tokens += usage.get("total_tokens", 0)
|
|
541
|
-
|
|
542
|
-
# Create trace document
|
|
543
|
-
trace_data = {
|
|
544
|
-
"trace_id": self.trace_id,
|
|
545
|
-
"api_key": self.tracer.api_key,
|
|
546
|
-
"name": self.name,
|
|
547
|
-
"project_name": self.project_name,
|
|
548
|
-
"created_at": datetime.fromtimestamp(self.start_time).isoformat(),
|
|
549
|
-
"duration": total_duration,
|
|
550
|
-
"token_counts": {
|
|
551
|
-
"prompt_tokens": total_prompt_tokens,
|
|
552
|
-
"completion_tokens": total_completion_tokens,
|
|
553
|
-
"total_tokens": total_tokens,
|
|
554
|
-
},
|
|
555
|
-
"entries": condensed_entries,
|
|
556
|
-
"empty_save": empty_save,
|
|
557
|
-
"overwrite": overwrite
|
|
558
|
-
}
|
|
559
|
-
|
|
560
|
-
if not empty_save:
|
|
561
|
-
connection = pika.BlockingConnection(
|
|
562
|
-
pika.ConnectionParameters(host=RABBITMQ_HOST, port=RABBITMQ_PORT))
|
|
563
|
-
channel = connection.channel()
|
|
564
|
-
|
|
565
|
-
channel.queue_declare(queue=RABBITMQ_QUEUE, durable=True)
|
|
566
|
-
|
|
567
|
-
channel.basic_publish(
|
|
568
|
-
exchange='',
|
|
569
|
-
routing_key=RABBITMQ_QUEUE,
|
|
570
|
-
body=json.dumps(trace_data),
|
|
571
|
-
properties=pika.BasicProperties(
|
|
572
|
-
delivery_mode=pika.DeliveryMode.Transient # Changed from Persistent to Transient
|
|
573
|
-
))
|
|
574
|
-
connection.close()
|
|
575
|
-
|
|
576
|
-
self.trace_manager_client.save_trace(trace_data, empty_save)
|
|
577
|
-
|
|
578
|
-
return self.trace_id, trace_data
|
|
579
|
-
|
|
580
|
-
def delete(self):
    """Delete this trace through the trace manager client.

    Returns:
        Whatever ``trace_manager_client.delete_trace`` returns for this
        object's ``trace_id``.
    """
    manager = self.trace_manager_client
    outcome = manager.delete_trace(self.trace_id)
    return outcome
|
|
582
|
-
|
|
583
|
-
class Tracer:
    """Singleton that tracks the currently active trace and provides tracing helpers.

    ``__new__`` always returns the same object, and ``__init__`` configures it
    only the first time it completes; later constructions return the
    already-configured instance unchanged.
    """

    _instance = None

    # Class-level fallback: if the first ``__init__`` raises (e.g. missing API
    # key) the singleton object already exists, and readers of
    # ``_current_trace`` should see "no active trace" instead of hitting an
    # AttributeError.
    _current_trace = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, api_key: Optional[str] = None):
        """Configure the singleton on first construction.

        Args:
            api_key: Judgment API key. When omitted, the ``JUDGMENT_API_KEY``
                environment variable is read at call time.

        Raises:
            ValueError: If no API key was passed and the environment variable
                is unset or empty.
        """
        if hasattr(self, 'initialized'):
            return

        # Resolve the environment fallback here, not in the signature: a
        # default of ``os.getenv(...)`` is evaluated once when the ``def`` is
        # executed (import time), so a key exported afterwards was ignored.
        if api_key is None:
            api_key = os.getenv("JUDGMENT_API_KEY")

        if not api_key:
            raise ValueError("Tracer must be configured with a Judgment API key")

        self.api_key: str = api_key
        self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
        self.depth: int = 0
        # Holds the active TraceClient (set by ``trace()``), not a string id.
        self._current_trace: Optional[TraceClient] = None
        self.initialized: bool = True

    @contextmanager
    def trace(self, name: str, project_name: str = "default_project", overwrite: bool = False) -> Generator[TraceClient, None, None]:
        """Start a new trace context using a context manager.

        Creates a ``TraceClient`` with a fresh UUID, makes it the current
        trace, opens a top-level span named after the trace, performs an
        initial empty save, and yields the client. The previously current
        trace is restored when the context exits.

        Args:
            name: Trace name; also used as the top-level span name
                (``"unnamed_trace"`` when falsy).
            project_name: Project the trace is filed under.
            overwrite: Forwarded to ``TraceClient`` and to the initial save.

        Yields:
            The newly created ``TraceClient``.
        """
        trace_id = str(uuid.uuid4())
        trace = TraceClient(self, trace_id, name, project_name=project_name, overwrite=overwrite)
        prev_trace = self._current_trace
        self._current_trace = trace

        # Automatically create top-level span
        with trace.span(name or "unnamed_trace"):
            try:
                # Save the trace to the database to handle Evaluations' trace_id referential integrity
                trace.save(empty_save=True, overwrite=overwrite)
                yield trace
            finally:
                self._current_trace = prev_trace

    def get_current_trace(self) -> Optional[TraceClient]:
        """Get the current trace context, or ``None`` when no trace is active."""
        return self._current_trace

    def observe(self, func=None, *, name=None, span_type: SpanType = "span"):
        """Decorator to trace function execution with detailed entry/exit information.

        Usable bare (``@tracer.observe``) or with arguments
        (``@tracer.observe(name=..., span_type=...)``). Coroutine functions get
        an async wrapper; everything else gets a sync wrapper. When no trace is
        active at call time the wrapped function is called untouched.

        Args:
            func: The function to trace
            name: Optional custom name for the function
            span_type: The type of span to use for this observation (default: "span")

        Returns:
            The wrapped function, or a decorator when ``func`` is omitted.
        """
        if func is None:
            # Called with arguments only: return a decorator awaiting the function.
            return lambda f: self.observe(f, name=name, span_type=span_type)

        if asyncio.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if not self._current_trace:
                    return await func(*args, **kwargs)

                span_name = name or func.__name__
                with self._current_trace.span(span_name, span_type=span_type) as span:
                    # NOTE(review): presumably redundant with the span_type=
                    # argument above; kept because span() is not visible here.
                    span.span_type = span_type

                    # Record inputs
                    span.record_input({
                        'args': list(args),
                        'kwargs': kwargs
                    })

                    # Execute function
                    result = await func(*args, **kwargs)

                    # Record output
                    span.record_output(result)

                    return result
            return async_wrapper
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if not self._current_trace:
                    return func(*args, **kwargs)

                span_name = name or func.__name__
                with self._current_trace.span(span_name, span_type=span_type) as span:
                    # NOTE(review): presumably redundant with the span_type=
                    # argument above; kept because span() is not visible here.
                    span.span_type = span_type

                    # Record inputs
                    span.record_input({
                        'args': list(args),
                        'kwargs': kwargs
                    })

                    # Execute function
                    result = func(*args, **kwargs)

                    # Record output
                    span.record_output(result)

                    return result
            return wrapper
|
|
690
|
-
|
|
691
|
-
def wrap(client: Any) -> Any:
    """
    Wraps an API client to add tracing capabilities.
    Supports OpenAI, Together, and Anthropic clients.

    The client's creation method (``chat.completions.create`` for
    OpenAI/Together, ``messages.create`` for Anthropic) is replaced in place
    with a version that records the call as an ``"llm"`` span on the tracer's
    current trace. When there is no tracer or no active trace the original
    method is called directly.

    Args:
        client: The API client instance to instrument.

    Returns:
        The same client object, now instrumented.

    Raises:
        ValueError: Propagated from ``_get_client_config`` for unsupported
            client types.
    """
    # Get the appropriate configuration for this client type
    span_name, original_create = _get_client_config(client)

    def traced_create(*args, **kwargs):
        # Look the tracer up on every call instead of capturing it when wrap()
        # runs: a client wrapped before the Tracer singleton was constructed
        # would otherwise never be traced.
        tracer = Tracer._instance
        # getattr guards against a singleton whose __init__ never completed.
        current_trace = getattr(tracer, "_current_trace", None)

        # Skip tracing if no active trace
        if not current_trace:
            return original_create(*args, **kwargs)

        with current_trace.span(span_name, span_type="llm") as span:
            # Format and record the input parameters
            input_data = _format_input_data(client, **kwargs)
            span.record_input(input_data)

            # Make the actual API call
            response = original_create(*args, **kwargs)

            # Format and record the output
            output_data = _format_output_data(client, response)
            span.record_output(output_data)

            return response

    # Replace the original method with our traced version
    if isinstance(client, (OpenAI, Together)):
        client.chat.completions.create = traced_create
    elif isinstance(client, Anthropic):
        client.messages.create = traced_create

    return client
|
|
727
|
-
|
|
728
|
-
# Helper functions for client-specific operations
|
|
729
|
-
|
|
730
|
-
def _get_client_config(client: ApiClient) -> tuple[str, callable]:
    """Look up the tracing span name and creation method for an API client.

    Args:
        client: An instance of OpenAI, Together, or Anthropic client

    Returns:
        tuple: (span_name, create_method)
            - span_name: String identifier for tracing
            - create_method: Reference to the client's creation method

    Raises:
        ValueError: If client type is not supported
    """
    # Checked in this order; the first matching client type wins.
    dispatch = (
        (OpenAI, "OPENAI_API_CALL", lambda c: c.chat.completions.create),
        (Together, "TOGETHER_API_CALL", lambda c: c.chat.completions.create),
        (Anthropic, "ANTHROPIC_API_CALL", lambda c: c.messages.create),
    )
    for client_type, span_name, get_create in dispatch:
        if isinstance(client, client_type):
            return span_name, get_create(client)
    raise ValueError(f"Unsupported client type: {type(client)}")
|
|
751
|
-
|
|
752
|
-
def _format_input_data(client: ApiClient, **kwargs) -> dict:
    """Pick the call parameters to record for a traced API request.

    ``model`` and ``messages`` are always taken from ``kwargs`` (``None`` when
    absent). For any client that is not OpenAI or Together, ``max_tokens`` is
    recorded as well.
    """
    recorded = {
        "model": kwargs.get("model"),
        "messages": kwargs.get("messages"),
    }
    if not isinstance(client, (OpenAI, Together)):
        # Anthropic requires additional max_tokens parameter
        recorded["max_tokens"] = kwargs.get("max_tokens")
    return recorded
|
|
769
|
-
|
|
770
|
-
def _format_output_data(client: ApiClient, response: Any) -> dict:
    """Reduce an API response to the content and usage recorded on a span.

    Returns:
        dict containing:
            - content: The generated text
            - usage: Token usage statistics. OpenAI/Together responses use the
              keys ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``;
              other (Anthropic) responses use ``input_tokens`` /
              ``output_tokens`` plus a ``total_tokens`` computed as their sum.
    """
    if isinstance(client, (OpenAI, Together)):
        text = response.choices[0].message.content
        stats = response.usage
        token_usage = {
            "prompt_tokens": stats.prompt_tokens,
            "completion_tokens": stats.completion_tokens,
            "total_tokens": stats.total_tokens,
        }
    else:
        # Anthropic has a different response structure
        text = response.content[0].text
        stats = response.usage
        consumed = stats.input_tokens
        produced = stats.output_tokens
        token_usage = {
            "input_tokens": consumed,
            "output_tokens": produced,
            "total_tokens": consumed + produced,
        }
    return {"content": text, "usage": token_usage}
|