judgeval 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- judgeval/__init__.py +5 -4
- judgeval/clients.py +6 -6
- judgeval/common/__init__.py +7 -2
- judgeval/common/exceptions.py +2 -3
- judgeval/common/logger.py +74 -49
- judgeval/common/s3_storage.py +30 -23
- judgeval/common/tracer.py +1302 -984
- judgeval/common/utils.py +416 -244
- judgeval/constants.py +73 -61
- judgeval/data/__init__.py +1 -1
- judgeval/data/custom_example.py +3 -2
- judgeval/data/datasets/dataset.py +80 -54
- judgeval/data/datasets/eval_dataset_client.py +131 -181
- judgeval/data/example.py +67 -43
- judgeval/data/result.py +11 -9
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +25 -16
- judgeval/data/trace.py +57 -29
- judgeval/data/trace_run.py +5 -11
- judgeval/evaluation_run.py +22 -82
- judgeval/integrations/langgraph.py +546 -184
- judgeval/judges/base_judge.py +1 -2
- judgeval/judges/litellm_judge.py +33 -11
- judgeval/judges/mixture_of_judges.py +128 -78
- judgeval/judges/together_judge.py +22 -9
- judgeval/judges/utils.py +14 -5
- judgeval/judgment_client.py +259 -271
- judgeval/rules.py +169 -142
- judgeval/run_evaluation.py +462 -305
- judgeval/scorers/api_scorer.py +20 -11
- judgeval/scorers/exceptions.py +1 -0
- judgeval/scorers/judgeval_scorer.py +77 -58
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
- judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
- judgeval/scorers/prompt_scorer.py +48 -37
- judgeval/scorers/score.py +86 -53
- judgeval/scorers/utils.py +11 -7
- judgeval/tracer/__init__.py +1 -1
- judgeval/utils/alerts.py +23 -12
- judgeval/utils/{data_utils.py → file_utils.py} +5 -9
- judgeval/utils/requests.py +29 -0
- judgeval/version_check.py +5 -2
- {judgeval-0.0.43.dist-info → judgeval-0.0.45.dist-info}/METADATA +79 -135
- judgeval-0.0.45.dist-info/RECORD +69 -0
- judgeval-0.0.43.dist-info/RECORD +0 -68
- {judgeval-0.0.43.dist-info → judgeval-0.0.45.dist-info}/WHEEL +0 -0
- {judgeval-0.0.43.dist-info → judgeval-0.0.45.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -1,6 +1,7 @@
 """
 Tracing system for judgeval that allows for function tracing using decorators.
 """
+
 # Standard library imports
 import asyncio
 import functools
@@ -16,9 +17,13 @@ import warnings
 import contextvars
 import sys
 import json
-from contextlib import
-
-
+from contextlib import (
+    contextmanager,
+    AbstractAsyncContextManager,
+    AbstractContextManager,
+)  # Import context manager bases
+from dataclasses import dataclass
+from datetime import datetime
 from http import HTTPStatus
 from typing import (
     Any,
@@ -37,9 +42,9 @@ from rich import print as rprint
 import types

 # Third-party imports
-import
+from requests import RequestException
+from judgeval.utils.requests import requests
 from litellm import cost_per_token as _original_cost_per_token
-from rich import print as rprint
 from openai import OpenAI, AsyncOpenAI
 from together import Together, AsyncTogether
 from anthropic import Anthropic, AsyncAnthropic
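The new `judgeval/utils/requests.py` module (+29 lines) is not shown in this section; only the import pattern above is visible, where call sites keep writing `requests.post(...)` unchanged while exception types still come from the upstream `requests` package. A minimal sketch of what such a drop-in wrapper could look like — the session reuse is an assumption, not the actual implementation:

```python
# Hypothetical sketch of judgeval/utils/requests.py; the real module is not
# shown in this diff section, only the drop-in usage pattern it enables.
import requests as _requests


class _RequestsWrapper:
    """Drop-in object so call sites can keep using `requests.post(...)`."""

    def __init__(self) -> None:
        self._session = _requests.Session()  # assumed: reuse connections across calls

    def post(self, url: str, **kwargs):
        return self._session.post(url, **kwargs)

    def get(self, url: str, **kwargs):
        return self._session.get(url, **kwargs)

    def delete(self, url: str, **kwargs):
        return self._session.delete(url, **kwargs)


requests = _RequestsWrapper()
```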
@@ -50,12 +55,7 @@ from judgeval.constants import (
     JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
     JUDGMENT_TRACES_SAVE_API_URL,
     JUDGMENT_TRACES_UPSERT_API_URL,
-    JUDGMENT_TRACES_USAGE_CHECK_API_URL,
-    JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
     JUDGMENT_TRACES_FETCH_API_URL,
-    RABBITMQ_HOST,
-    RABBITMQ_PORT,
-    RABBITMQ_QUEUE,
     JUDGMENT_TRACES_DELETE_API_URL,
     JUDGMENT_PROJECT_DELETE_API_URL,
     JUDGMENT_TRACES_SPANS_BATCH_API_URL,
@@ -70,32 +70,37 @@ from judgeval.common.exceptions import JudgmentAPIError

 # Standard library imports needed for the new class
 import concurrent.futures
-from collections.abc import Iterator, AsyncIterator
+from collections.abc import Iterator, AsyncIterator  # Add Iterator and AsyncIterator
 import queue
 import atexit

 # Define context variables for tracking the current trace and the current span within a trace
-current_trace_var = contextvars.ContextVar[Optional[
-
+current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
+    "current_trace", default=None
+)
+current_span_var = contextvars.ContextVar[Optional[str]](
+    "current_span", default=None
+)  # ContextVar for the active span id

 # Define type aliases for better code readability and maintainability
-ApiClient: TypeAlias = Union[
-
+ApiClient: TypeAlias = Union[
+    OpenAI,
+    Together,
+    Anthropic,
+    AsyncOpenAI,
+    AsyncAnthropic,
+    AsyncTogether,
+    genai.Client,
+    genai.client.AsyncClient,
+]  # Supported API clients
+SpanType = Literal["span", "tool", "llm", "evaluation", "chain"]

-# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
-@dataclass
-class EvaluationConfig:
-    """Configuration for triggering an evaluation from the handler."""
-    scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
-    example: Example
-    model: Optional[str] = None
-    log_results: Optional[bool] = True
-# --- End Evaluation Config Dataclass ---

 # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
 @dataclass
 class TraceAnnotation:
     """Represents a single annotation for a trace span."""
+
     span_id: str
     text: str
     label: str
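The `ContextVar` definitions above follow the standard set/reset-token discipline that the rest of the file relies on; a self-contained illustration of the pattern:

```python
# Standalone illustration of the ContextVar set/reset pattern used for span tracking.
import contextvars
from typing import Optional

current_span_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_span", default=None
)


def enter_span(span_id: str) -> None:
    parent_span_id = current_span_var.get()  # parent span id, or None at the root
    token = current_span_var.set(span_id)    # this span becomes the current one
    try:
        print(f"span={span_id} parent={parent_span_id}")
    finally:
        current_span_var.reset(token)        # restore the parent on exit
```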
@@ -105,24 +110,27 @@ class TraceAnnotation:
         """Convert the annotation to a dictionary format for storage/transmission."""
         return {
             "span_id": self.span_id,
-            "annotation": {
-                "text": self.text,
-                "label": self.label,
-                "score": self.score
-            }
+            "annotation": {"text": self.text, "label": self.label, "score": self.score},
         }
-
+
+
 class TraceManagerClient:
     """
     Client for handling trace endpoints with the Judgment API
-
+

     Operations include:
     - Fetching a trace by id
     - Saving a trace
     - Deleting a trace
     """
-
+
+    def __init__(
+        self,
+        judgment_api_key: str,
+        organization_id: str,
+        tracer: Optional["Tracer"] = None,
+    ):
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
         self.tracer = tracer
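With the single-line `to_dict` above, an annotation serializes to a nested payload; for example (the numeric type of `score` is assumed):

```python
# Payload shape produced by TraceAnnotation.to_dict() after this change.
ann = TraceAnnotation(
    span_id="span-1", text="cites a nonexistent paper", label="hallucination", score=0
)
assert ann.to_dict() == {
    "span_id": "span-1",
    "annotation": {"text": "cites a nonexistent paper", "label": "hallucination", "score": 0},
}
```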
@@ -139,17 +147,19 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to fetch traces: {response.text}")
-
+
         return response.json()

-    def save_trace(
+    def save_trace(
+        self, trace_data: dict, offline_mode: bool = False, final_save: bool = True
+    ):
         """
         Saves a trace to the Judgment Supabase and optionally to S3 if configured.

@@ -158,12 +168,12 @@ class TraceManagerClient:
             offline_mode: Whether running in offline mode
             final_save: Whether this is the final save (controls S3 saving)
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
-
+
         Returns:
             dict: Server response containing UI URL and other metadata
         """
         # Save to Judgment API
-
+
         def fallback_encoder(obj):
             """
             Custom JSON encoder fallback.
@@ -180,7 +190,7 @@ class TraceManagerClient:
             except Exception as e:
                 # If both fail, you might return a placeholder or re-raise
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
@@ -188,71 +198,46 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code == HTTPStatus.BAD_REQUEST:
-            raise ValueError(
+            raise ValueError(
+                f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}"
+            )
         elif response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save trace data: {response.text}")
-
+
         # Parse server response
         server_response = response.json()
-
+
         # If S3 storage is enabled, save to S3 only on final save
         if self.tracer and self.tracer.use_s3 and final_save:
             try:
                 s3_key = self.tracer.s3_storage.save_trace(
                     trace_data=trace_data,
                     trace_id=trace_data["trace_id"],
-                    project_name=trace_data["project_name"]
+                    project_name=trace_data["project_name"],
                 )
                 print(f"Trace also saved to S3 at key: {s3_key}")
             except Exception as e:
                 warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
+
         if not offline_mode and "ui_results_url" in server_response:
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
-
-        return server_response

-
-        """
-        Check if the organization can use the requested number of traces without exceeding limits.
-
-        Args:
-            count: Number of traces to check for (default: 1)
-
-        Returns:
-            dict: Server response with rate limit status and usage info
-
-        Raises:
-            ValueError: If rate limits would be exceeded or other errors occur
-        """
-        response = requests.post(
-            JUDGMENT_TRACES_USAGE_CHECK_API_URL,
-            json={"count": count},
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            },
-            verify=True
-        )
-
-        if response.status_code == HTTPStatus.FORBIDDEN:
-            # Rate limits exceeded
-            error_data = response.json()
-            raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
-        elif response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to check usage limits: {response.text}")
-
-        return response.json()
+        return server_response

-    def upsert_trace(
+    def upsert_trace(
+        self,
+        trace_data: dict,
+        offline_mode: bool = False,
+        show_link: bool = True,
+        final_save: bool = True,
+    ):
         """
         Upserts a trace to the Judgment API (always overwrites if exists).

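`save_trace` and `upsert_trace` both rely on the same `fallback_encoder` idiom, which is plain `json.dumps(..., default=...)`; a standalone illustration:

```python
# Standalone illustration of the fallback-encoder idiom used by save_trace/upsert_trace.
import json


def fallback_encoder(obj):
    # Try the object's string form first; if even that fails, emit a typed
    # placeholder instead of letting json.dumps raise.
    try:
        return str(obj)
    except Exception as e:
        return f"<Unserializable object of type {type(obj).__name__}: {e}>"


class Opaque:
    def __repr__(self):
        return "Opaque(handle=0x1)"


print(json.dumps({"ok": 1, "blob": Opaque()}, default=fallback_encoder))
# {"ok": 1, "blob": "Opaque(handle=0x1)"}
```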
@@ -261,10 +246,11 @@ class TraceManagerClient:
             offline_mode: Whether running in offline mode
             show_link: Whether to show the UI link (for live tracing)
             final_save: Whether this is the final save (controls S3 saving)
-
+
         Returns:
             dict: Server response containing UI URL and other metadata
         """
+
         def fallback_encoder(obj):
             """
             Custom JSON encoder fallback.
@@ -277,7 +263,7 @@ class TraceManagerClient:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)

         response = requests.post(
@@ -286,63 +272,34 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to upsert trace data: {response.text}")
-
+
         # Parse server response
         server_response = response.json()
-
+
         # If S3 storage is enabled, save to S3 only on final save
         if self.tracer and self.tracer.use_s3 and final_save:
             try:
                 s3_key = self.tracer.s3_storage.save_trace(
                     trace_data=trace_data,
                     trace_id=trace_data["trace_id"],
-                    project_name=trace_data["project_name"]
+                    project_name=trace_data["project_name"],
                 )
                 print(f"Trace also saved to S3 at key: {s3_key}")
             except Exception as e:
                 warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
+
         if not offline_mode and show_link and "ui_results_url" in server_response:
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
-
-        return server_response

-
-        """
-        Update trace usage counters after successfully saving traces.
-
-        Args:
-            count: Number of traces to count (default: 1)
-
-        Returns:
-            dict: Server response with updated usage information
-
-        Raises:
-            ValueError: If the update fails
-        """
-        response = requests.post(
-            JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
-            json={"count": count},
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            },
-            verify=True
-        )
-
-        if response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to update usage counters: {response.text}")
-
-        return response.json()
+        return server_response

     ## TODO: Should have a log endpoint, endpoint should also support batched payloads
     def save_annotation(self, annotation: TraceAnnotation):
@@ -351,24 +308,24 @@ class TraceManagerClient:
             "annotation": {
                 "text": annotation.text,
                 "label": annotation.label,
-                "score": annotation.score
-            }
-        }
+                "score": annotation.score,
+            },
+        }

         response = requests.post(
             JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
             json=json_data,
             headers={
-
-
-
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.judgment_api_key}",
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save annotation: {response.text}")
-
+
         return response.json()

     def delete_trace(self, trace_id: str):
@@ -383,15 +340,15 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete trace: {response.text}")
-
+
         return response.json()
-
+
     def delete_traces(self, trace_ids: List[str]):
         """
         Delete a batch of traces from the database.
@@ -404,15 +361,15 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete trace: {response.text}")
-
+
         return response.json()
-
+
     def delete_project(self, project_name: str):
         """
         Deletes a project from the server. Which also deletes all evaluations and traces associated with the project.
@@ -425,31 +382,31 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete traces: {response.text}")
-
+
         return response.json()


 class TraceClient:
     """Client for managing a single trace context"""
-
+
     def __init__(
         self,
-        tracer:
+        tracer: "Tracer",
         trace_id: Optional[str] = None,
         name: str = "default",
-        project_name: str = None,
+        project_name: str | None = None,
         overwrite: bool = False,
         rules: Optional[List[Rule]] = None,
        enable_monitoring: bool = True,
         enable_evaluations: bool = True,
         parent_trace_id: Optional[str] = None,
-        parent_name: Optional[str] = None
+        parent_name: Optional[str] = None,
     ):
         self.name = name
         self.trace_id = trace_id or str(uuid.uuid4())
@@ -461,39 +418,48 @@ class TraceClient:
         self.enable_evaluations = enable_evaluations
         self.parent_trace_id = parent_trace_id
         self.parent_name = parent_name
+        self.customer_id: Optional[str] = None  # Added customer_id attribute
+        self.tags: List[Union[str, set, tuple]] = []  # Added tags attribute
+        self.metadata: Dict[str, Any] = {}
+        self.has_notification: Optional[bool] = False  # Initialize has_notification
         self.trace_spans: List[TraceSpan] = []
         self.span_id_to_span: Dict[str, TraceSpan] = {}
         self.evaluation_runs: List[EvaluationRun] = []
         self.annotations: List[TraceAnnotation] = []
-        self.start_time =
-
-
-        self.
-
-
-
+        self.start_time: Optional[float] = (
+            None  # Will be set after first successful save
+        )
+        self.trace_manager_client = TraceManagerClient(
+            tracer.api_key, tracer.organization_id, tracer
+        )
+        self._span_depths: Dict[str, int] = {}  # NEW: To track depth of active spans
+
         # Get background span service from tracer
-        self.background_span_service =
+        self.background_span_service = (
+            tracer.get_background_span_service() if tracer else None
+        )

     def get_current_span(self):
         """Get the current span from the context var"""
         return self.tracer.get_current_span()
-
+
     def set_current_span(self, span: Any):
         """Set the current span from the context var"""
         return self.tracer.set_current_span(span)
-
+
     def reset_current_span(self, token: Any):
         """Reset the current span from the context var"""
         self.tracer.reset_current_span(token)
-
+
     @contextmanager
     def span(self, name: str, span_type: SpanType = "span"):
         """Context manager for creating a trace span, managing the current span via contextvars"""
         is_first_span = len(self.trace_spans) == 0
         if is_first_span:
             try:
-                trace_id, server_response = self.
+                trace_id, server_response = self.save(
+                    overwrite=self.overwrite, final_save=False
+                )
                 # Set start_time after first successful save
                 if self.start_time is None:
                     self.start_time = time.time()
@@ -501,19 +467,23 @@ class TraceClient:
             except Exception as e:
                 warnings.warn(f"Failed to save initial trace for live tracking: {e}")
         start_time = time.time()
-
+
         # Generate a unique ID for *this specific span invocation*
         span_id = str(uuid.uuid4())
-
-        parent_span_id =
-
-
+
+        parent_span_id = (
+            self.get_current_span()
+        )  # Get ID of the parent span from context var
+        token = self.set_current_span(
+            span_id
+        )  # Set *this* span's ID as the current one
+
         current_depth = 0
         if parent_span_id and parent_span_id in self._span_depths:
             current_depth = self._span_depths[parent_span_id] + 1
-
-        self._span_depths[span_id] = current_depth
-
+
+        self._span_depths[span_id] = current_depth  # Store depth by span_id
+
         span = TraceSpan(
             span_id=span_id,
             trace_id=self.trace_id,
@@ -525,23 +495,21 @@ class TraceClient:
             function=name,
         )
         self.add_span(span)
-
-
-
+
         # Queue span with initial state (input phase)
         if self.background_span_service:
             self.background_span_service.queue_span(span, span_state="input")
-
+
         try:
             yield self
         finally:
             duration = time.time() - start_time
             span.duration = duration
-
+
             # Queue span with completed state (output phase)
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="completed")
-
+
             # Clean up depth tracking for this span_id
             if span_id in self._span_depths:
                 del self._span_depths[span_id]
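A usage sketch for the `span` context manager above, assuming `trace` is an active `TraceClient` (construction elided):

```python
# Usage sketch: nested spans pick up their parent from the context variable.
with trace.span("retrieve_documents", span_type="tool"):
    # Any span opened here becomes a child: parent_span_id is read from the
    # context variable set by the enclosing span, and its depth is parent + 1.
    with trace.span("rerank", span_type="llm"):
        pass
```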
@@ -561,12 +529,11 @@ class TraceClient:
         expected_tools: Optional[List[str]] = None,
         additional_metadata: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-        span_id: Optional[str] = None,
-        log_results: Optional[bool] = True
+        span_id: Optional[str] = None,  # <<< ADDED optional span_id parameter
     ):
         if not self.enable_evaluations:
             return
-
+
         start_time = time.time()  # Record start time

         try:
@@ -574,21 +541,35 @@ class TraceClient:
             if not scorers:
                 warnings.warn("No valid scorers available for evaluation")
                 return
-
+
             # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
-            if self.rules and any(
-
-
+            if self.rules and any(
+                isinstance(scorer, JudgevalScorer) for scorer in scorers
+            ):
+                raise ValueError(
+                    "Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types."
+                )
+
         except Exception as e:
             warnings.warn(f"Failed to load scorers: {str(e)}")
             return
-
+
         # If example is not provided, create one from the individual parameters
         if example is None:
             # Check if any of the individual parameters are provided
-            if any(
-
-
+            if any(
+                param is not None
+                for param in [
+                    input,
+                    actual_output,
+                    expected_output,
+                    context,
+                    retrieval_context,
+                    tools_called,
+                    expected_tools,
+                    additional_metadata,
+                ]
+            ):
                 example = Example(
                     input=input,
                     actual_output=actual_output,
@@ -600,12 +581,14 @@ class TraceClient:
                     additional_metadata=additional_metadata,
                 )
             else:
-                raise ValueError(
-
+                raise ValueError(
+                    "Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
+                )
+
         # Check examples before creating evaluation run
-
+
         # check_examples([example], scorers)
-
+
         # --- Modification: Capture span_id immediately ---
         # span_id_at_eval_call = current_span_var.get()
         # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
@@ -617,37 +600,32 @@ class TraceClient:
         # Combine the trace-level rules with any evaluation-specific rules)
         eval_run = EvaluationRun(
             organization_id=self.tracer.organization_id,
-            log_results=log_results,
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
-
-
+            f"{span_id_to_use}-"  # Keep original eval name format using context var if available
+            f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
             examples=[example],
             scorers=scorers,
             model=model,
-            metadata={},
             judgment_api_key=self.tracer.api_key,
             override=self.overwrite,
-            trace_span_id=span_id_to_use,
-            rules=self.rules  # Use the combined rules
+            trace_span_id=span_id_to_use,
         )
-
+
         self.add_eval_run(eval_run, start_time)  # Pass start_time to record_evaluation
-
+
         # Queue evaluation run through background service
         if self.background_span_service and span_id_to_use:
             # Get the current span data to avoid race conditions
             current_span = self.span_id_to_span.get(span_id_to_use)
             if current_span:
                 self.background_span_service.queue_evaluation_run(
-                    eval_run,
-                    span_id=span_id_to_use,
-                    span_data=current_span
+                    eval_run, span_id=span_id_to_use, span_data=current_span
                 )
-
+
     def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
-        # --- Modification: Use span_id from eval_run ---
-        current_span_id = eval_run.trace_span_id
+        # --- Modification: Use span_id from eval_run ---
+        current_span_id = eval_run.trace_span_id  # Get ID from the eval_run object
         # print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
         # --- End Modification ---

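A hedged call sketch for `async_evaluate` based on the parameters visible above (the scorer constructor is illustrative, and note that v0.0.45 drops the `log_results` parameter):

```python
# Call sketch for TraceClient.async_evaluate; scorer construction is illustrative.
trace.async_evaluate(
    scorers=[AnswerRelevancyScorer(threshold=0.7)],  # assumed scorer API
    input="What is the capital of France?",
    actual_output="Paris.",
    model="gpt-4o",
)
# Passing log_results=... now raises a TypeError, since the parameter was removed.
```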
@@ -658,10 +636,10 @@ class TraceClient:
         self.evaluation_runs.append(eval_run)

     def add_annotation(self, annotation: TraceAnnotation):
-
-
-
-
+        """Add an annotation to this trace context"""
+        self.annotations.append(annotation)
+        return self
+
     def record_input(self, inputs: dict):
         current_span_id = self.get_current_span()
         if current_span_id:
@@ -670,17 +648,20 @@ class TraceClient:
             if "self" in inputs:
                 del inputs["self"]
             span.inputs = inputs
-
+
             # Queue span with input data
-
-            self.background_span_service
-
+            try:
+                if self.background_span_service:
+                    self.background_span_service.queue_span(span, span_state="input")
+            except Exception as e:
+                warnings.warn(f"Failed to queue span with input data: {e}")
+
     def record_agent_name(self, agent_name: str):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.agent_name = agent_name
-
+
             # Queue span with agent_name data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="agent_name")
@@ -695,11 +676,11 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.state_before = state
-
+
             # Queue span with state_before data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="state_before")
-
+
     def record_state_after(self, state: dict):
         """Records the agent's state after a tool execution on the current span.

@@ -710,7 +691,7 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.state_after = state
-
+
             # Queue span with state_after data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="state_after")
@@ -720,19 +701,19 @@ class TraceClient:
         try:
             result = await coroutine
             setattr(span, field, result)
-
+
             # Queue span with output data now that coroutine is complete
             if self.background_span_service and field == "output":
                 self.background_span_service.queue_span(span, span_state="output")
-
+
             return result
         except Exception as e:
             setattr(span, field, f"Error: {str(e)}")
-
+
             # Queue span even if there was an error
             if self.background_span_service and field == "output":
                 self.background_span_service.queue_span(span, span_state="output")
-
+
             raise

     def record_output(self, output: Any):
@@ -740,56 +721,56 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.output = "<pending>" if inspect.iscoroutine(output) else output
-
+
             if inspect.iscoroutine(output):
                 asyncio.create_task(self._update_coroutine(span, output, "output"))
-
+
             # # Queue span with output data (unless it's pending)
             if self.background_span_service and not inspect.iscoroutine(output):
                 self.background_span_service.queue_span(span, span_state="output")

-            return span
+            return span  # Return the created entry
         # Removed else block - original didn't have one
-        return None
-
+        return None  # Return None if no span_id found
+
     def record_usage(self, usage: TraceUsage):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.usage = usage
-
+
             # Queue span with usage data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="usage")
-
-            return span
+
+            return span  # Return the created entry
         # Removed else block - original didn't have one
-        return None
-
+        return None  # Return None if no span_id found
+
     def record_error(self, error: Dict[str, Any]):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.error = error
-
+
             # Queue span with error data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="error")
-
+
             return span
         return None
-
+
     def add_span(self, span: TraceSpan):
         """Add a trace span to this trace context"""
         self.trace_spans.append(span)
         self.span_id_to_span[span.span_id] = span
         return self
-
+
     def print(self):
         """Print the complete trace with proper visual structure"""
         for span in self.trace_spans:
             span.print_span()
-
+
     def get_duration(self) -> float:
         """
         Get the total duration of this trace
@@ -798,57 +779,23 @@ class TraceClient:
             return 0.0  # No duration if trace hasn't been saved yet
         return time.time() - self.start_time

-    def save(
-
-
-        Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
-        """
-        # Set start_time if this is the first save
-        if self.start_time is None:
-            self.start_time = time.time()
-
-        # Calculate total elapsed time
-        total_duration = self.get_duration()
-        # Create trace document - Always use standard keys for top-level counts
-        trace_data = {
-            "trace_id": self.trace_id,
-            "name": self.name,
-            "project_name": self.project_name,
-            "created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
-            "duration": total_duration,
-            "trace_spans": [span.model_dump() for span in self.trace_spans],
-            "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
-            "overwrite": overwrite,
-            "offline_mode": self.tracer.offline_mode,
-            "parent_trace_id": self.parent_trace_id,
-            "parent_name": self.parent_name
-        }
-        # --- Log trace data before saving ---
-        server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
-
-        # upload annotations
-        # TODO: batch to the log endpoint
-        for annotation in self.annotations:
-            self.trace_manager_client.save_annotation(annotation)
-
-        return self.trace_id, server_response
-
-    def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
+    def save(
+        self, overwrite: bool = False, final_save: bool = False
+    ) -> Tuple[str, dict]:
         """
         Save the current trace to the database with rate limiting checks.
         First checks usage limits, then upserts the trace if allowed.
-
+
         Args:
             overwrite: Whether to overwrite existing traces
             final_save: Whether this is the final save (updates usage counters)
-
+
         Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
         """

-
         # Calculate total elapsed time
         total_duration = self.get_duration()
-
+
         # Create trace document
         trace_data = {
             "trace_id": self.trace_id,
@@ -861,35 +808,20 @@ class TraceClient:
             "overwrite": overwrite,
             "offline_mode": self.tracer.offline_mode,
             "parent_trace_id": self.parent_trace_id,
-            "parent_name": self.parent_name
+            "parent_name": self.parent_name,
+            "customer_id": self.customer_id,
+            "tags": self.tags,
+            "metadata": self.metadata,
         }
-
-        # Check usage limits first
-        try:
-            usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
-            # Usage check passed silently - no need to show detailed info
-        except ValueError as e:
-            # Rate limit exceeded
-            warnings.warn(f"Rate limit check failed for live tracing: {e}")
-            raise e
-
+
         # If usage check passes, upsert the trace
         server_response = self.trace_manager_client.upsert_trace(
-            trace_data,
+            trace_data,
             offline_mode=self.tracer.offline_mode,
             show_link=not final_save,  # Show link only on initial save, not final save
-            final_save=final_save  # Pass final_save to control S3 saving
+            final_save=final_save,  # Pass final_save to control S3 saving
         )

-        # Update usage counters only on final save
-        if final_save:
-            try:
-                usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
-                # Usage updated silently - no need to show detailed usage info
-            except ValueError as e:
-                # Log warning but don't fail the trace save since the trace was already saved
-                warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
-
         # Upload annotations
         # TODO: batch to the log endpoint
         for annotation in self.annotations:
@@ -900,8 +832,77 @@ class TraceClient:

     def delete(self):
         return self.trace_manager_client.delete_trace(self.trace_id)
-
-    def
+
+    def update_metadata(self, metadata: dict):
+        """
+        Set metadata for this trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+
+        Supported keys:
+        - customer_id: ID of the customer using this trace
+        - tags: List of tags for this trace
+        - has_notification: Whether this trace has a notification
+        - overwrite: Whether to overwrite existing traces
+        - name: Name of the trace
+        """
+        for k, v in metadata.items():
+            if k == "customer_id":
+                if v is not None:
+                    self.customer_id = str(v)
+                else:
+                    self.customer_id = None
+            elif k == "tags":
+                if isinstance(v, list):
+                    # Validate that all items in the list are of the expected types
+                    for item in v:
+                        if not isinstance(item, (str, set, tuple)):
+                            raise ValueError(
+                                f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
+                            )
+                    self.tags = v
+                else:
+                    raise ValueError(
+                        f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
+                    )
+            elif k == "has_notification":
+                if not isinstance(v, bool):
+                    raise ValueError(
+                        f"has_notification must be a boolean, got {type(v)}"
+                    )
+                self.has_notification = v
+            elif k == "overwrite":
+                if not isinstance(v, bool):
+                    raise ValueError(f"overwrite must be a boolean, got {type(v)}")
+                self.overwrite = v
+            elif k == "name":
+                self.name = v
+            else:
+                self.metadata[k] = v
+
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for this trace.
+
+        Args:
+            customer_id: The customer ID to set
+        """
+        self.update_metadata({"customer_id": customer_id})
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for this trace.
+
+        Args:
+            tags: List of tags to set
+        """
+        self.update_metadata({"tags": tags})
+
+
+def _capture_exception_for_trace(
+    current_trace: Optional["TraceClient"], exc_info: ExcInfo
+):
     if not current_trace:
         return

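A usage sketch for the new metadata API added here (values are illustrative):

```python
# Usage sketch for the v0.0.45 trace metadata API.
trace.set_customer_id("cust_42")              # stored as str via update_metadata
trace.set_tags(["checkout", ("beta", "eu")])  # str, set, or tuple items allowed
trace.update_metadata({
    "has_notification": True,  # must be bool, else ValueError
    "name": "checkout-flow",   # renames the trace
    "experiment": "v2",        # unrecognized keys land in trace.metadata
})
# save() then serializes customer_id, tags, and metadata into the trace document.
```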
@@ -909,18 +910,20 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
     formatted_exception = {
         "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
         "message": str(exc_value) if exc_value else "No exception message",
-        "traceback": traceback.format_tb(exc_traceback_obj)
+        "traceback": traceback.format_tb(exc_traceback_obj)
+        if exc_traceback_obj
+        else [],
     }
-
+
     # This is where we specially handle exceptions that we might want to collect additional data for.
     # When we do this, always try checking the module from sys.modules instead of importing. This will
     # Let us support a wider range of exceptions without needing to import them for all clients.
-
-    # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
+
+    # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
     # The alternative is to hand select libraries we want from sys.modules and check for them:
     # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;

-
+    # General HTTP Like errors
     try:
         url = getattr(getattr(exc_value, "request", None), "url", None)
         status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
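Concretely, the payload recorded by `_capture_exception_for_trace` now looks like this (illustrated with a locally raised exception):

```python
# Shape of the error payload recorded via current_trace.record_error(...).
import sys
import traceback

try:
    raise RuntimeError("boom")
except RuntimeError:
    exc_type, exc_value, exc_tb = sys.exc_info()
    formatted_exception = {
        "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
        "message": str(exc_value) if exc_value else "No exception message",
        "traceback": traceback.format_tb(exc_tb) if exc_tb else [],
    }
    # {'type': 'RuntimeError', 'message': 'boom', 'traceback': ['  File ...']}
```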
@@ -929,34 +932,43 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
             "url": url if url else "Unknown URL",
             "status_code": status_code if status_code else None,
         }
-    except Exception
+    except Exception:
         pass

     current_trace.record_error(formatted_exception)
-
+
     # Queue the span with error state through background service
     if current_trace.background_span_service:
         current_span_id = current_trace.get_current_span()
         if current_span_id and current_span_id in current_trace.span_id_to_span:
             error_span = current_trace.span_id_to_span[current_span_id]
-            current_trace.background_span_service.queue_span(
+            current_trace.background_span_service.queue_span(
+                error_span, span_state="error"
+            )
+

 class BackgroundSpanService:
     """
     Background service for queueing and batching trace spans for efficient saving.
-
+
     This service:
     - Queues spans as they complete
     - Batches them for efficient network usage
     - Sends spans periodically or when batches reach a certain size
     - Handles automatic flushing when the main event terminates
     """
-
-    def __init__(
-
+
+    def __init__(
+        self,
+        judgment_api_key: str,
+        organization_id: str,
+        batch_size: int = 10,
+        flush_interval: float = 5.0,
+        num_workers: int = 1,
+    ):
         """
         Initialize the background span service.
-
+
         Args:
             judgment_api_key: API key for Judgment service
             organization_id: Organization ID
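Constructing the service with the defaults from the new `__init__` signature (key and org id are placeholders):

```python
# Construction sketch for BackgroundSpanService using the defaults shown above.
service = BackgroundSpanService(
    judgment_api_key="sk-placeholder",
    organization_id="org-placeholder",
    batch_size=10,       # flush once 10 queued items are drained into a batch
    flush_interval=5.0,  # ...or when a non-empty batch is 5 seconds old
    num_workers=1,       # background sender threads (clamped to at least 1)
)
```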
@@ -969,41 +981,41 @@ class BackgroundSpanService:
         self.batch_size = batch_size
         self.flush_interval = flush_interval
         self.num_workers = max(1, num_workers)  # Ensure at least 1 worker
-
+
         # Queue for pending spans
-        self._span_queue = queue.Queue()
-
+        self._span_queue: queue.Queue[Dict[str, Any]] = queue.Queue()
+
         # Background threads for processing spans
-        self._worker_threads = []
+        self._worker_threads: List[threading.Thread] = []
         self._shutdown_event = threading.Event()
-
+
         # Track spans that have been sent
-        self._sent_spans = set()
-
+        # self._sent_spans = set()
+
         # Register cleanup on exit
         atexit.register(self.shutdown)
-
+
         # Start the background workers
         self._start_workers()
-
+
     def _start_workers(self):
         """Start the background worker threads."""
         for i in range(self.num_workers):
             if len(self._worker_threads) < self.num_workers:
                 worker_thread = threading.Thread(
-                    target=self._worker_loop,
-                    daemon=True,
-                    name=f"SpanWorker-{i+1}"
+                    target=self._worker_loop, daemon=True, name=f"SpanWorker-{i + 1}"
                 )
                 worker_thread.start()
                 self._worker_threads.append(worker_thread)
-
+
     def _worker_loop(self):
         """Main worker loop that processes spans in batches."""
         batch = []
         last_flush_time = time.time()
-        pending_task_count =
-
+        pending_task_count = (
+            0  # Track how many tasks we've taken from queue but not marked done
+        )
+
         while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
             try:
                 # First, do a blocking get to wait for at least one item
@@ -1015,7 +1027,7 @@ class BackgroundSpanService:
             except queue.Empty:
                 # No new spans, continue to check for flush conditions
                 pass
-
+
             # Then, do non-blocking gets to drain any additional available items
             # up to our batch size limit
             while len(batch) < self.batch_size:
@@ -1026,24 +1038,23 @@ class BackgroundSpanService:
                 except queue.Empty:
                     # No more items immediately available
                     break
-
+
             current_time = time.time()
-            should_flush = (
-
-                (batch and (current_time - last_flush_time) >= self.flush_interval)
+            should_flush = len(batch) >= self.batch_size or (
+                batch and (current_time - last_flush_time) >= self.flush_interval
             )
-
+
             if should_flush and batch:
                 self._send_batch(batch)
-
+
                 # Only mark tasks as done after successful sending
                 for _ in range(pending_task_count):
                     self._span_queue.task_done()
                 pending_task_count = 0  # Reset counter
-
+
                 batch.clear()
                 last_flush_time = current_time
-
+
         except Exception as e:
             warnings.warn(f"Error in span service worker loop: {e}")
             # On error, still need to mark tasks as done to prevent hanging
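The rewritten flush condition combines a size trigger with a time trigger; a compact standalone illustration:

```python
# Illustration of the worker-loop flush decision: flush when the batch is full,
# or when a non-empty batch has aged past flush_interval.
import time

batch_size, flush_interval = 10, 5.0
batch = ["span-1"]                   # small batch...
last_flush_time = time.time() - 6.0  # ...that is already 6 seconds old

should_flush = len(batch) >= batch_size or (
    batch and (time.time() - last_flush_time) >= flush_interval
)
print(should_flush)  # True: the time trigger fires even though the batch is small
```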
@@ -1051,53 +1062,50 @@ class BackgroundSpanService:
                 self._span_queue.task_done()
             pending_task_count = 0
             batch.clear()
-
+
         # Final flush on shutdown
         if batch:
             self._send_batch(batch)
             # Mark remaining tasks as done
             for _ in range(pending_task_count):
                 self._span_queue.task_done()
-
+
     def _send_batch(self, batch: List[Dict[str, Any]]):
         """
         Send a batch of spans to the server.
-
+
         Args:
             batch: List of span dictionaries to send
         """
         if not batch:
             return
-
+
         try:
-            # Group
+            # Group items by type for different endpoints
             spans_to_send = []
             evaluation_runs_to_send = []
-
+
             for item in batch:
-                if item[
-                    spans_to_send.append(item[
-                elif item[
-                    evaluation_runs_to_send.append(item[
-
+                if item["type"] == "span":
+                    spans_to_send.append(item["data"])
+                elif item["type"] == "evaluation_run":
+                    evaluation_runs_to_send.append(item["data"])
+
             # Send spans if any
             if spans_to_send:
                 self._send_spans_batch(spans_to_send)
-
+
             # Send evaluation runs if any
             if evaluation_runs_to_send:
                 self._send_evaluation_runs_batch(evaluation_runs_to_send)
-
+
         except Exception as e:
-            warnings.warn(f"Failed to send
-
+            warnings.warn(f"Failed to send batch: {e}")
+
     def _send_spans_batch(self, spans: List[Dict[str, Any]]):
         """Send a batch of spans to the spans endpoint."""
-        payload = {
-
-            "organization_id": self.organization_id
-        }
-
+        payload = {"spans": spans, "organization_id": self.organization_id}
+
         # Serialize with fallback encoder
         def fallback_encoder(obj):
             try:
@@ -1107,10 +1115,10 @@ class BackgroundSpanService:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         try:
             serialized_data = json.dumps(payload, default=fallback_encoder)
-
+
             # Send the actual HTTP request to the batch endpoint
             response = requests.post(
                 JUDGMENT_TRACES_SPANS_BATCH_API_URL,
@@ -1118,21 +1126,22 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30  # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(
-
-
-
+                warnings.warn(
+                    f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending spans batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to serialize or send spans batch: {e}")
-
+
     def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
         """Send a batch of evaluation runs with their associated span data to the endpoint."""
         # Structure payload to include both evaluation run data and span data
@@ -1142,22 +1151,23 @@ class BackgroundSpanService:
             entry = {
                 "evaluation_run": {
                     # Extract evaluation run fields (excluding span-specific fields)
-                    key: value
-
+                    key: value
+                    for key, value in eval_data.items()
+                    if key not in ["associated_span_id", "span_data", "queued_at"]
                 },
                 "associated_span": {
-                    "span_id": eval_data.get(
-                    "span_data": eval_data.get(
+                    "span_id": eval_data.get("associated_span_id"),
+                    "span_data": eval_data.get("span_data"),
                 },
-                "queued_at": eval_data.get(
+                "queued_at": eval_data.get("queued_at"),
             }
             evaluation_entries.append(entry)
-
+
         payload = {
             "organization_id": self.organization_id,
-            "evaluation_entries": evaluation_entries  # Each entry contains both eval run + span data
+            "evaluation_entries": evaluation_entries,  # Each entry contains both eval run + span data
         }
-
+
         # Serialize with fallback encoder
         def fallback_encoder(obj):
             try:
@@ -1167,10 +1177,10 @@ class BackgroundSpanService:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         try:
             serialized_data = json.dumps(payload, default=fallback_encoder)
-
+
             # Send the actual HTTP request to the batch endpoint
             response = requests.post(
                 JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
@@ -1178,25 +1188,26 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30  # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(
-
-
-
+                warnings.warn(
+                    f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending evaluation runs batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to send evaluation runs batch: {e}")
-
+
     def queue_span(self, span: TraceSpan, span_state: str = "input"):
         """
         Queue a span for background sending.
-
+
         Args:
             span: The TraceSpan object to queue
             span_state: State of the span ("input", "output", "completed")
@@ -1207,15 +1218,17 @@ class BackgroundSpanService:
             "data": {
                 **span.model_dump(),
                 "span_state": span_state,
-                "queued_at": time.time()
-            }
+                "queued_at": time.time(),
+            },
         }
         self._span_queue.put(span_data)
-
-    def queue_evaluation_run(self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan):
+
+    def queue_evaluation_run(
+        self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan
+    ):
         """
         Queue an evaluation run for background sending.
-
+
         Args:
             evaluation_run: The EvaluationRun object to queue
             span_id: The span ID associated with this evaluation run
@@ -1228,11 +1241,11 @@ class BackgroundSpanService:
             **evaluation_run.model_dump(),
             "associated_span_id": span_id,
             "span_data": span_data.model_dump(),  # Include span data to avoid race conditions
-            "queued_at": time.time()
-        }
+            "queued_at": time.time(),
+        },
     }
     self._span_queue.put(eval_data)
-
+
     def flush(self):
         """Force immediate sending of all queued spans."""
         try:
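`flush()` below relies on the standard `queue.Queue.task_done()`/`join()` handshake between producers and worker threads. A minimal sketch of that contract, with names invented for illustration:

```python
import queue
import threading
import time

span_queue: queue.Queue = queue.Queue()

def worker():
    while True:
        item = span_queue.get()
        try:
            pass  # batch and send the item here
        finally:
            span_queue.task_done()  # lets join() below unblock

threading.Thread(target=worker, daemon=True).start()
span_queue.put({"span_state": "completed", "queued_at": time.time()})
span_queue.join()  # blocks until every queued item is marked done, as in flush()
```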
@@ -1240,89 +1253,101 @@ class BackgroundSpanService:
             self._span_queue.join()
         except Exception as e:
             warnings.warn(f"Error during flush: {e}")
-
+
     def shutdown(self):
         """Shutdown the background service and flush remaining spans."""
         if self._shutdown_event.is_set():
             return
-
+
         try:
             # Signal shutdown to stop new items from being queued
             self._shutdown_event.set()
-
+
             # Try to flush any remaining spans
             try:
                 self.flush()
             except Exception as e:
-                warnings.warn(f"Error during final flush: {e}")
+                warnings.warn(f"Error during final flush: {e}")
         except Exception as e:
             warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
         finally:
             # Clear the worker threads list (daemon threads will be killed automatically)
             self._worker_threads.clear()
-
+
     def get_queue_size(self) -> int:
         """Get the current size of the span queue."""
         return self._span_queue.qsize()
 
+
 class _DeepTracer:
     _instance: Optional["_DeepTracer"] = None
     _lock: threading.Lock = threading.Lock()
     _refcount: int = 0
-    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
-    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
+    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
+        "_deep_profiler_span_stack", default=[]
+    )
+    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
+        "_deep_profiler_skip_stack", default=[]
+    )
     _original_sys_trace: Optional[Callable] = None
     _original_threading_trace: Optional[Callable] = None
 
-    def __init__(self, tracer: "Tracer"):
+    def __init__(self, tracer: "Tracer"):
         self._tracer = tracer
 
     def _get_qual_name(self, frame) -> str:
         func_name = frame.f_code.co_name
         module_name = frame.f_globals.get("__name__", "unknown_module")
-
+
         try:
             func = frame.f_globals.get(func_name)
             if func is None:
                 return f"{module_name}.{func_name}"
             if hasattr(func, "__qualname__"):
-                return f"{module_name}.{func.__qualname__}"
+                return f"{module_name}.{func.__qualname__}"
+            return f"{module_name}.{func_name}"
         except Exception:
             return f"{module_name}.{func_name}"
-
-    def __new__(cls, tracer: "Tracer"):
+
+    def __new__(cls, tracer: "Tracer"):
         with cls._lock:
             if cls._instance is None:
                 cls._instance = super().__new__(cls)
             return cls._instance
-
+
     def _should_trace(self, frame):
         # Skip stack is maintained by the tracer as an optimization to skip earlier
         # frames in the call stack that we've already determined should be skipped
         skip_stack = self._skip_stack.get()
         if len(skip_stack) > 0:
             return False
-
+
         func_name = frame.f_code.co_name
         module_name = frame.f_globals.get("__name__", None)
-
         func = frame.f_globals.get(func_name)
-        if func and (hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")):
+        if func and (
+            hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
+        ):
             return False
 
         if (
             not module_name
-            or func_name.startswith("<")
-            or func_name.startswith("__")
+            or func_name.startswith("<")  # ex: <listcomp>
+            or func_name.startswith("__")
+            and func_name != "__call__"  # dunders
             or not self._is_user_code(frame.f_code.co_filename)
         ):
             return False
-
+
         return True
-
+
     @functools.cache
     def _is_user_code(self, filename: str):
-        return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+        return (
+            bool(filename)
+            and not filename.startswith("<")
+            and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+        )
 
     def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
         """Cooperative trace function for sys.settrace that chains with existing tracers."""
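`_is_user_code` gates deep tracing by filename. A standalone approximation of the filter; the blocklist construction is an assumption, since `_TRACE_FILEPATH_BLOCKLIST` is defined elsewhere in the module:

```python
import functools
import os
import sysconfig

# Assumed stand-in for _TRACE_FILEPATH_BLOCKLIST: stdlib and site-packages roots.
_BLOCKLIST = (sysconfig.get_paths()["stdlib"], sysconfig.get_paths()["purelib"])

@functools.cache
def is_user_code(filename: str) -> bool:
    # Reject synthetic frames like "<string>" or "<listcomp>" and anything
    # that resolves into an interpreter/library path.
    return (
        bool(filename)
        and not filename.startswith("<")
        and not os.path.realpath(filename).startswith(_BLOCKLIST)
    )

print(is_user_code("example.py"))  # True for a file in the working directory
```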
@@ -1334,18 +1359,20 @@ class _DeepTracer:
         except Exception:
             # If the original tracer fails, continue with our tracing
             pass
-
+
         # Then do our own tracing
         our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
-
+
         # Return our tracer to continue tracing, but respect the original's decision
         # If the original tracer returned None (stop tracing), we should respect that
         if original_result is None and self._original_sys_trace:
             return None
-
+
         return our_result or original_result
-
-    def _cooperative_threading_trace(self, frame: types.FrameType, event: str, arg: Any):
+
+    def _cooperative_threading_trace(
+        self, frame: types.FrameType, event: str, arg: Any
+    ):
         """Cooperative trace function for threading.settrace that chains with existing tracers."""
         # First, call the original threading trace function if it exists
         original_result = None
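Both cooperative trace functions follow the same chaining recipe: call whatever tracer was installed first, then run our own, and stop if the original asked to stop. Here is that recipe in isolation, assuming nothing about judgeval itself:

```python
import sys
import types
from typing import Any, Callable, Optional

_original: Optional[Callable] = sys.gettrace()  # tracer installed before ours, if any

def cooperative_trace(frame: types.FrameType, event: str, arg: Any):
    original_result = None
    if _original:
        try:
            original_result = _original(frame, event, arg)
        except Exception:
            pass  # a failing upstream tracer should not break ours
    our_result = cooperative_trace  # keep receiving events for this frame
    # If the original tracer asked to stop tracing, respect that decision.
    if original_result is None and _original:
        return None
    return our_result or original_result

sys.settrace(cooperative_trace)
sys.settrace(_original)  # restore the previous tracer when done
```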
@@ -1355,44 +1382,48 @@ class _DeepTracer:
         except Exception:
             # If the original tracer fails, continue with our tracing
             pass
-
+
         # Then do our own tracing
         our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
-
+
         # Return our tracer to continue tracing, but respect the original's decision
         # If the original tracer returned None (stop tracing), we should respect that
         if original_result is None and self._original_threading_trace:
             return None
-
+
         return our_result or original_result
-
-    def _trace(self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable):
+
+    def _trace(
+        self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
+    ):
         frame.f_trace_lines = False
         frame.f_trace_opcodes = False
 
         if not self._should_trace(frame):
             return
-
+
         if event not in ("call", "return", "exception"):
             return
-
+
         current_trace = self._tracer.get_current_trace()
         if not current_trace:
             return
-
+
         parent_span_id = self._tracer.get_current_span()
         if not parent_span_id:
             return
 
         qual_name = self._get_qual_name(frame)
         instance_name = None
-        if "self" in frame.f_locals:
-            instance = frame.f_locals["self"]
+        if "self" in frame.f_locals:
+            instance = frame.f_locals["self"]
             class_name = instance.__class__.__name__
-            class_identifiers = getattr(self._tracer, "class_identifiers", {})
-            instance_name = get_instance_prefixed_name(instance, class_name, class_identifiers)
+            class_identifiers = getattr(self._tracer, "class_identifiers", {})
+            instance_name = get_instance_prefixed_name(
+                instance, class_name, class_identifiers
+            )
         skip_stack = self._skip_stack.get()
-
+
         if event == "call":
             # If we have entries in the skip stack and the current qual_name matches the top entry,
             # push it again to track nesting depth and skip
@@ -1402,9 +1433,9 @@ class _DeepTracer:
                 skip_stack.append(qual_name)
                 self._skip_stack.set(skip_stack)
                 return
-
+
             should_trace = self._should_trace(frame)
-
+
             if not should_trace:
                 if not skip_stack:
                     self._skip_stack.set([qual_name])
@@ -1416,35 +1447,37 @@ class _DeepTracer:
                 skip_stack.pop()
                 self._skip_stack.set(skip_stack)
             return
-
+
         if skip_stack:
             return
-
+
         span_stack = self._span_stack.get()
         if event == "call":
             if not self._should_trace(frame):
                 return
-
+
             span_id = str(uuid.uuid4())
-
+
             parent_depth = current_trace._span_depths.get(parent_span_id, 0)
             depth = parent_depth + 1
-
+
             current_trace._span_depths[span_id] = depth
-
+
             start_time = time.time()
-
-            span_stack.append(
-
-
-
-
-
+
+            span_stack.append(
+                {
+                    "span_id": span_id,
+                    "parent_span_id": parent_span_id,
+                    "function": qual_name,
+                    "start_time": start_time,
+                }
+            )
             self._span_stack.set(span_stack)
-
+
             token = self._tracer.set_current_span(span_id)
             frame.f_locals["_judgment_span_token"] = token
-
+
             span = TraceSpan(
                 span_id=span_id,
                 trace_id=current_trace.trace_id,
@@ -1454,57 +1487,55 @@ class _DeepTracer:
                 span_type="span",
                 parent_span_id=parent_span_id,
                 function=qual_name,
-                agent_name=instance_name
+                agent_name=instance_name,
             )
             current_trace.add_span(span)
-
+
             inputs = {}
             try:
                 args_info = inspect.getargvalues(frame)
                 for arg in args_info.args:
                     try:
                         inputs[arg] = args_info.locals.get(arg)
-                    except:
+                    except Exception:
                         inputs[arg] = "<<Unserializable>>"
                 current_trace.record_input(inputs)
             except Exception as e:
-                current_trace.record_input({
-                    "error": str(e)
-                })
-
+                current_trace.record_input({"error": str(e)})
+
         elif event == "return":
             if not span_stack:
                 return
-
+
             current_id = self._tracer.get_current_span()
-
+
             span_data = None
             for i, entry in enumerate(reversed(span_stack)):
                 if entry["span_id"] == current_id:
-                    span_data = span_stack.pop(-(i+1))
+                    span_data = span_stack.pop(-(i + 1))
                     self._span_stack.set(span_stack)
                     break
-
+
             if not span_data:
                 return
-
+
             start_time = span_data["start_time"]
             duration = time.time() - start_time
-
+
             current_trace.span_id_to_span[span_data["span_id"]].duration = duration
 
             if arg is not None:
-                # exception handling will take priority.
+                # exception handling will take priority.
                 current_trace.record_output(arg)
-
+
             if span_data["span_id"] in current_trace._span_depths:
                 del current_trace._span_depths[span_data["span_id"]]
-
+
             if span_stack:
                 self._tracer.set_current_span(span_stack[-1]["span_id"])
             else:
                 self._tracer.set_current_span(span_data["parent_span_id"])
-
+
             if "_judgment_span_token" in frame.f_locals:
                 self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
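The `return` branch pops its span entry with a reversed scan so out-of-order returns still find the right frame. The pop-index arithmetic in miniature:

```python
span_stack = [{"span_id": "a"}, {"span_id": "b"}, {"span_id": "c"}]
current_id = "b"

span_data = None
for i, entry in enumerate(reversed(span_stack)):
    if entry["span_id"] == current_id:
        # i counts from the end of the list, so -(i + 1) addresses the same
        # element with a forward (negative) index for pop().
        span_data = span_stack.pop(-(i + 1))
        break

print(span_data)   # {'span_id': 'b'}
print(span_stack)  # [{'span_id': 'a'}, {'span_id': 'c'}]
```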
@@ -1513,10 +1544,9 @@ class _DeepTracer:
             if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
                 return
             _capture_exception_for_trace(current_trace, arg)
-
-
+
         return continuation_func
-
+
     def __enter__(self):
         with self._lock:
             self._refcount += 1
@@ -1524,14 +1554,14 @@ class _DeepTracer:
             # Store the existing trace functions before setting ours
             self._original_sys_trace = sys.gettrace()
             self._original_threading_trace = threading.gettrace()
-
+
             self._skip_stack.set([])
             self._span_stack.set([])
-
+
             sys.settrace(self._cooperative_sys_trace)
             threading.settrace(self._cooperative_threading_trace)
         return self
-
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         with self._lock:
             self._refcount -= 1
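`__enter__`/`__exit__` install the trace hooks only on the first nested entry and restore the saved originals on the last exit. A reduced sketch of that refcounting; `_DeepTracer` is a singleton, so this sketch keeps the shared state on the class instead:

```python
import sys
import threading

class RefCountedTracer:
    _lock = threading.Lock()
    _refcount = 0
    _original = None

    def _trace(self, frame, event, arg):
        return self._trace  # no-op tracer for the sketch

    def __enter__(self):
        cls = type(self)
        with cls._lock:
            cls._refcount += 1
            if cls._refcount == 1:
                cls._original = sys.gettrace()  # save whatever was installed
                sys.settrace(self._trace)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        cls = type(self)
        with cls._lock:
            cls._refcount -= 1
            if cls._refcount == 0:
                sys.settrace(cls._original)  # restore, don't just clear
                cls._original = None

with RefCountedTracer():
    with RefCountedTracer():  # nested use keeps the hook installed exactly once
        pass
```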
@@ -1539,7 +1569,7 @@ class _DeepTracer:
                 # Restore the original trace functions instead of setting to None
                 sys.settrace(self._original_sys_trace)
                 threading.settrace(self._original_threading_trace)
-
+
                 # Clean up the references
                 self._original_sys_trace = None
                 self._original_threading_trace = None
@@ -1547,7 +1577,7 @@ class _DeepTracer:
 
 # Below commented out function isn't used anymore?
 
-# def log(self, message: str, level: str = "info"):
+# def log(self, message: str, level: str = "info"):
 #     """ Log a message with the span context """
 #     current_trace = self._tracer.get_current_trace()
 #     if current_trace:
@@ -1555,31 +1585,29 @@ class _DeepTracer:
 #     else:
 #         print(f"[{level}] {message}")
 #     current_trace.record_output({"log": message})
-
-class Tracer:
-    _instance = None
 
+
+class Tracer:
     # Tracer.current_trace class variable is currently used in wrap()
     # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
     # Should be fine to do so as long as we keep Tracer as a singleton
     current_trace: Optional[TraceClient] = None
     # current_span_id: Optional[str] = None
 
-    trace_across_async_contexts: bool = False  # BY default, we don't trace across async contexts
-
-
-        if cls._instance is None:
-            cls._instance = super(Tracer, cls).__new__(cls)
-        return cls._instance
+    trace_across_async_contexts: bool = (
+        False  # BY default, we don't trace across async contexts
+    )
 
     def __init__(
-        self,
-        api_key: str = os.getenv("JUDGMENT_API_KEY"),
-        project_name: str = None,
+        self,
+        api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
+        project_name: str | None = None,
         rules: Optional[List[Rule]] = None,  # Added rules parameter
-        organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
-        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
-        == "true",
+        organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
+        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
+        == "true",
+        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
+        == "true",
         # S3 configuration
         use_s3: bool = False,
         s3_bucket_name: Optional[str] = None,
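The constructor's environment-driven defaults all follow one convention, extracted here for reference:

```python
import os

# Same convention as the Tracer defaults above: a string env var compared
# against "true" after lowercasing, so JUDGMENT_MONITORING=TRUE also works.
enable_monitoring = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true"
enable_evaluations = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true"
api_key = os.getenv("JUDGMENT_API_KEY")         # None if unset -> error at init
organization_id = os.getenv("JUDGMENT_ORG_ID")  # None if unset -> error at init
print(enable_monitoring, enable_evaluations)
```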
@@ -1587,116 +1615,132 @@ class Tracer:
         s3_aws_secret_access_key: Optional[str] = None,
         s3_region_name: Optional[str] = None,
         offline_mode: bool = False,
-        deep_tracing: bool = False,
-        trace_across_async_contexts: bool = False,
+        deep_tracing: bool = False,  # Deep tracing is disabled by default
+        trace_across_async_contexts: bool = False,  # BY default, we don't trace across async contexts
         # Background span service configuration
         enable_background_spans: bool = True,  # Enable background span service by default
         span_batch_size: int = 50,  # Number of spans to batch before sending
         span_flush_interval: float = 1.0,  # Time in seconds between automatic flushes
-        span_num_workers: int = 10 # Number of worker threads for span processing
-    ):
-        if not api_key:
-
-
+        span_num_workers: int = 10,  # Number of worker threads for span processing
+    ):
+        if not api_key:
+            raise ValueError("Tracer must be configured with a Judgment API key")
+
+        try:
             result, response = validate_api_key(api_key)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        except Exception as e:
+            print(f"Issue with verifying API key, disabling monitoring: {e}")
+            enable_monitoring = False
+            result = True
+
+        if not result:
+            raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
+
+        if not organization_id:
+            raise ValueError("Tracer must be configured with an Organization ID")
+        if use_s3 and not s3_bucket_name:
+            raise ValueError("S3 bucket name must be provided when use_s3 is True")
+
+        self.api_key: str = api_key
+        self.project_name: str = project_name or str(uuid.uuid4())
+        self.organization_id: str = organization_id
+        self.rules: List[Rule] = rules or []  # Store rules at tracer level
+        self.traces: List[Trace] = []
+        self.enable_monitoring: bool = enable_monitoring
+        self.enable_evaluations: bool = enable_evaluations
+        self.class_identifiers: Dict[
+            str, str
+        ] = {}  # Dictionary to store class identifiers
+        self.span_id_to_previous_span_id: Dict[str, str | None] = {}
+        self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
+        self.current_span_id: Optional[str] = None
+        self.current_trace: Optional[TraceClient] = None
+        self.trace_across_async_contexts: bool = trace_across_async_contexts
+        Tracer.trace_across_async_contexts = trace_across_async_contexts
+
+        # Initialize S3 storage if enabled
+        self.use_s3 = use_s3
+        if use_s3:
+            from judgeval.common.s3_storage import S3Storage
+
+            try:
                 self.s3_storage = S3Storage(
                     bucket_name=s3_bucket_name,
                     aws_access_key_id=s3_aws_access_key_id,
                     aws_secret_access_key=s3_aws_secret_access_key,
-                    region_name=s3_region_name
+                    region_name=s3_region_name,
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
-                f"project_name='{self.project_name}'. Due to the singleton pattern, the original project_name will be used. "
-                "To use a different project name, ensure the first Tracer initialization uses the desired project name.",
-                RuntimeWarning
+            except Exception as e:
+                print(f"Issue with initializing S3 storage, disabling S3: {e}")
+                self.use_s3 = False
+
+        self.offline_mode: bool = offline_mode
+        self.deep_tracing: bool = deep_tracing  # NEW: Store deep tracing setting
+
+        # Initialize background span service
+        self.enable_background_spans: bool = enable_background_spans
+        self.background_span_service: Optional[BackgroundSpanService] = None
+        if enable_background_spans and not offline_mode:
+            self.background_span_service = BackgroundSpanService(
+                judgment_api_key=api_key,
+                organization_id=organization_id,
+                batch_size=span_batch_size,
+                flush_interval=span_flush_interval,
+                num_workers=span_num_workers,
             )
 
-    def set_current_span(self, span_id: str):
-        self.span_id_to_previous_span_id[span_id] = self.current_span_id
+    def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
+        self.span_id_to_previous_span_id[span_id] = self.current_span_id
         self.current_span_id = span_id
         Tracer.current_span_id = span_id
         try:
             token = current_span_var.set(span_id)
-        except:
+        except Exception:
             token = None
         return token
-
+
     def get_current_span(self) -> Optional[str]:
         try:
             current_span_var_val = current_span_var.get()
-        except:
+        except Exception:
             current_span_var_val = None
-        return (
-
-
-
-
+        return (
+            (self.current_span_id or current_span_var_val)
+            if self.trace_across_async_contexts
+            else current_span_var_val
+        )
+
+    def reset_current_span(
+        self,
+        token: Optional[contextvars.Token[str | None]] = None,
+        span_id: Optional[str] = None,
+    ):
         try:
-
-
+            if token:
+                current_span_var.reset(token)
+        except Exception:
             pass
-
-
-
+        if not span_id:
+            span_id = self.current_span_id
+        if span_id:
+            self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
+            Tracer.current_span_id = self.current_span_id
+
+    def set_current_trace(
+        self, trace: TraceClient
+    ) -> Optional[contextvars.Token[TraceClient | None]]:
         """
         Set the current trace context in contextvars
         """
-        self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
+        self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
         self.current_trace = trace
         Tracer.current_trace = trace
         try:
             token = current_trace_var.set(trace)
-        except:
+        except Exception:
             token = None
         return token
-
+
     def get_current_trace(self) -> Optional[TraceClient]:
         """
         Get the current trace context.
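`set_current_span`/`set_current_trace` return the `contextvars` token and the reset methods consume it. The token round-trip on its own:

```python
import contextvars
from typing import Optional

current_span_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_span", default=None
)

token = current_span_var.set("span-123")  # what set_current_span hands back
assert current_span_var.get() == "span-123"
current_span_var.reset(token)             # what reset_current_span does with it
assert current_span_var.get() is None
```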
@@ -1707,72 +1751,69 @@ class Tracer:
         """
         try:
             current_trace_var_val = current_trace_var.get()
-        except:
+        except Exception:
             current_trace_var_val = None
-
-
-
-
-
-
-
-
-
-
-
-
-        # If neither is available, return None
-        return None
-
-    def reset_current_trace(self, token: Optional[str] = None, trace_id: Optional[str] = None):
-        if not trace_id and self.current_trace:
-            trace_id = self.current_trace.trace_id
+        return (
+            (self.current_trace or current_trace_var_val)
+            if self.trace_across_async_contexts
+            else current_trace_var_val
+        )
+
+    def reset_current_trace(
+        self,
+        token: Optional[contextvars.Token[TraceClient | None]] = None,
+        trace_id: Optional[str] = None,
+    ):
         try:
-
-
+            if token:
+                current_trace_var.reset(token)
+        except Exception:
             pass
-
-
+        if not trace_id and self.current_trace:
+            trace_id = self.current_trace.trace_id
+        if trace_id:
+            self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
+            Tracer.current_trace = self.current_trace
 
     @contextmanager
     def trace(
-        self,
-        name: str,
-        project_name: str = None,
+        self,
+        name: str,
+        project_name: str | None = None,
         overwrite: bool = False,
-        rules: Optional[List[Rule]] = None  # Added rules parameter
+        rules: Optional[List[Rule]] = None,  # Added rules parameter
     ) -> Generator[TraceClient, None, None]:
         """Start a new trace context using a context manager"""
         trace_id = str(uuid.uuid4())
         project = project_name if project_name is not None else self.project_name
-
+
         # Get parent trace info from context
         parent_trace = self.get_current_trace()
         parent_trace_id = None
         parent_name = None
-
+
         if parent_trace:
             parent_trace_id = parent_trace.trace_id
             parent_name = parent_trace.name
 
         trace = TraceClient(
-            self,
-            trace_id,
-            name,
-            project_name=project,
+            self,
+            trace_id,
+            name,
+            project_name=project,
             overwrite=overwrite,
             rules=self.rules,  # Pass combined rules to the trace client
             enable_monitoring=self.enable_monitoring,
             enable_evaluations=self.enable_evaluations,
             parent_trace_id=parent_trace_id,
-            parent_name=parent_name
+            parent_name=parent_name,
         )
-
+
         # Set the current trace in context variables
         token = self.set_current_trace(trace)
-
+
         # Automatically create top-level span
-        with trace.span(name or "unnamed_trace"):
+        with trace.span(name or "unnamed_trace"):
             try:
                 # Save the trace to the database to handle Evaluations' trace_id referential integrity
                 yield trace
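A hedged usage sketch for the context manager above; it assumes `JUDGMENT_API_KEY` and `JUDGMENT_ORG_ID` are set, and all names are illustrative:

```python
from judgeval.common.tracer import Tracer

tracer = Tracer(project_name="demo")  # api_key / org id fall back to env vars

with tracer.trace("checkout-flow") as trace:
    # Everything in this block runs inside the auto-created top-level span;
    # a nested tracer.trace() call records this trace as its parent.
    tracer.log("starting checkout")
```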
@@ -1780,101 +1821,110 @@ class Tracer:
             # Reset the context variable
             self.reset_current_trace(token)
 
-
     def log(self, msg: str, label: str = "log", score: int = 1):
         """Log a message with the current span context"""
         current_span_id = self.get_current_span()
         current_trace = self.get_current_trace()
-        if current_span_id:
+        if current_span_id and current_trace:
             annotation = TraceAnnotation(
-                span_id=current_span_id,
-                text=msg,
-                label=label,
-                score=score
+                span_id=current_span_id, text=msg, label=label, score=score
             )
-
             current_trace.add_annotation(annotation)
 
         rprint(f"[bold]{label}:[/bold] {msg}")
-
-    def identify(
+
+    def identify(
+        self,
+        identifier: str,
+        track_state: bool = False,
+        track_attributes: Optional[List[str]] = None,
+        field_mappings: Optional[Dict[str, str]] = None,
+    ):
         """
         Class decorator that associates a class with a custom identifier and enables state tracking.
-
+
         This decorator creates a mapping between the class name and the provided
         identifier, which can be useful for tagging, grouping, or referencing
         classes in a standardized way. It also enables automatic state capture
         for instances of the decorated class when used with tracing.
-
+
         Args:
             identifier: The identifier to associate with the decorated class.
                 This will be used as the instance name in traces.
-            track_state: Whether to automatically capture the state (attributes)
+            track_state: Whether to automatically capture the state (attributes)
                 of instances before and after function execution. Defaults to False.
             track_attributes: Optional list of specific attribute names to track.
-                If None, all non-private attributes (not starting with '_')
+                If None, all non-private attributes (not starting with '_')
                 will be tracked when track_state=True.
-            field_mappings: Optional dictionary mapping internal attribute names to
+            field_mappings: Optional dictionary mapping internal attribute names to
                 display names in the captured state. For example:
-                {"system_prompt": "instructions"} will capture the
+                {"system_prompt": "instructions"} will capture the
                 'instructions' attribute as 'system_prompt' in the state.
-
+
         Example:
             @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
             class User:
                 # Class implementation
         """
+
         def decorator(cls):
             class_name = cls.__name__
             self.class_identifiers[class_name] = {
                 "identifier": identifier,
                 "track_state": track_state,
                 "track_attributes": track_attributes,
-                "field_mappings": field_mappings or {}
+                "field_mappings": field_mappings or {},
             }
             return cls
-
+
         return decorator
-
-    def _capture_instance_state(
+
+    def _capture_instance_state(
+        self, instance: Any, class_config: Dict[str, Any]
+    ) -> Dict[str, Any]:
         """
         Capture the state of an instance based on class configuration.
         Args:
             instance: The instance to capture the state of.
-            class_config: Configuration dictionary for state capture,
+            class_config: Configuration dictionary for state capture,
                 expected to contain 'track_attributes' and 'field_mappings'.
         """
-        track_attributes = class_config.get("track_attributes")
-        field_mappings = class_config.get("field_mappings")
-
+        track_attributes = class_config.get("track_attributes")
+        field_mappings = class_config.get("field_mappings")
+
         if track_attributes:
-
             state = {attr: getattr(instance, attr, None) for attr in track_attributes}
         else:
-            state = {k: v for k, v in instance.__dict__.items() if not k.startswith("_")}
-
+            state = {
+                k: v for k, v in instance.__dict__.items() if not k.startswith("_")
+            }
 
         if field_mappings:
-            state["field_mappings"] = field_mappings
+            state["field_mappings"] = field_mappings
 
         return state
-
-
+
     def _get_instance_state_if_tracked(self, args):
         """
         Extract instance state if the instance should be tracked.
-
+
         Returns the captured state dict if tracking is enabled, None otherwise.
         """
-        if args and hasattr(args[0], "__class__"):
+        if args and hasattr(args[0], "__class__"):
             instance = args[0]
             class_name = instance.__class__.__name__
-            if (
-                class_name in self.class_identifiers
-                and self.class_identifiers[class_name].get("track_state", False)
-            ):
-                return self._capture_instance_state(instance, self.class_identifiers[class_name])
-
+            if (
+                class_name in self.class_identifiers
+                and isinstance(self.class_identifiers[class_name], dict)
+                and self.class_identifiers[class_name].get("track_state", False)
+            ):
+                return self._capture_instance_state(
+                    instance, self.class_identifiers[class_name]
+                )
+
+    def _conditionally_capture_and_record_state(
+        self, trace_client_instance: TraceClient, args: tuple, is_before: bool
+    ):
         """Captures instance state if tracked and records it via the trace_client."""
         state = self._get_instance_state_if_tracked(args)
         if state:
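What `_capture_instance_state` produces for a tracked instance, replayed on a toy class; the config dict mirrors what `identify` stores:

```python
class User:
    def __init__(self):
        self.name = "Ada"
        self._secret = "hidden"  # leading underscore: never captured by default

class_config = {
    "track_attributes": None,
    "field_mappings": {"system_prompt": "instructions"},
}
instance = User()

# No explicit track_attributes -> capture every non-private attribute.
state = {k: v for k, v in instance.__dict__.items() if not k.startswith("_")}
if class_config["field_mappings"]:
    state["field_mappings"] = class_config["field_mappings"]

print(state)  # {'name': 'Ada', 'field_mappings': {'system_prompt': 'instructions'}}
```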
@@ -1882,11 +1932,20 @@ class Tracer:
             trace_client_instance.record_state_before(state)
         else:
             trace_client_instance.record_state_after(state)
-
-    def observe(
+
+    def observe(
+        self,
+        func=None,
+        *,
+        name=None,
+        span_type: SpanType = "span",
+        project_name: str | None = None,
+        overwrite: bool = False,
+        deep_tracing: bool | None = None,
+    ):
         """
         Decorator to trace function execution with detailed entry/exit information.
-
+
         Args:
             func: The function to decorate
             name: Optional custom name for the span (defaults to function name)
@@ -1897,56 +1956,71 @@ class Tracer:
             If None, uses the tracer's default setting.
         """
         # If monitoring is disabled, return the function as is
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            if not self.enable_monitoring:
+                return func if func else lambda f: f
+
+            if func is None:
+                return lambda f: self.observe(
+                    f,
+                    name=name,
+                    span_type=span_type,
+                    project_name=project_name,
+                    overwrite=overwrite,
+                    deep_tracing=deep_tracing,
+                )
+
+            # Use provided name or fall back to function name
+            original_span_name = name or func.__name__
+
+            # Store custom attributes on the function object
+            func._judgment_span_name = original_span_name
+            func._judgment_span_type = span_type
+
+            # Use the provided deep_tracing value or fall back to the tracer's default
+            use_deep_tracing = (
+                deep_tracing if deep_tracing is not None else self.deep_tracing
+            )
+        except Exception:
+            return func
+
         if asyncio.iscoroutinefunction(func):
+
             @functools.wraps(func)
             async def async_wrapper(*args, **kwargs):
                 nonlocal original_span_name
                 class_name = None
-                instance_name = None
                 span_name = original_span_name
                 agent_name = None
 
-                if args and hasattr(args[0], "__class__"):
+                if args and hasattr(args[0], "__class__"):
                     class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
+                    agent_name = get_instance_prefixed_name(
+                        args[0], class_name, self.class_identifiers
+                    )
 
                 # Get current trace from context
                 current_trace = self.get_current_trace()
-
+
                 # If there's no current trace, create a root trace
                 if not current_trace:
                     trace_id = str(uuid.uuid4())
-                    project = project_name if project_name is not None else self.project_name
-
+                    project = (
+                        project_name if project_name is not None else self.project_name
+                    )
+
                     # Create a new trace client to serve as the root
                     current_trace = TraceClient(
                         self,
                         trace_id,
-                        span_name,
+                        span_name,  # MODIFIED: Use span_name directly
                         project_name=project,
                         overwrite=overwrite,
                         rules=self.rules,
                         enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations
+                        enable_evaluations=self.enable_evaluations,
                     )
-
+
                     # Save empty trace and set trace context
                     # current_trace.save(empty_save=True, overwrite=overwrite)
                     trace_token = self.set_current_trace(current_trace)
@@ -1954,7 +2028,9 @@ class Tracer:
                     try:
                         # Use span for the function execution within the root trace
                         # This sets the current_span_var
-                        with current_trace.span(span_name, span_type=span_type) as span:
+                        with current_trace.span(
+                            span_name, span_type=span_type
+                        ) as span:  # MODIFIED: Use span_name directly
                             # Record inputs
                             inputs = combine_args_kwargs(func, args, kwargs)
                             span.record_input(inputs)
@@ -1962,50 +2038,66 @@ class Tracer:
                             span.record_agent_name(agent_name)
 
                             # Capture state before execution
-                            self._conditionally_capture_and_record_state(
-
-
-
-
-
-
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=True
+                            )
+
+                            try:
+                                if use_deep_tracing:
+                                    with _DeepTracer(self):
+                                        result = await func(*args, **kwargs)
+                                else:
                                     result = await func(*args, **kwargs)
-
-
-
-
-
-
+                            except Exception as e:
+                                _capture_exception_for_trace(
+                                    current_trace, sys.exc_info()
+                                )
+                                raise e
+
+                            # Capture state after execution
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=False
+                            )
+
                             # Record output
                             span.record_output(result)
                             return result
                     finally:
                         # Flush background spans before saving the trace
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        try:
+                            complete_trace_data = {
+                                "trace_id": current_trace.trace_id,
+                                "name": current_trace.name,
+                                "created_at": datetime.utcfromtimestamp(
+                                    current_trace.start_time
+                                ).isoformat(),
+                                "duration": current_trace.get_duration(),
+                                "trace_spans": [
+                                    span.model_dump()
+                                    for span in current_trace.trace_spans
+                                ],
+                                "overwrite": overwrite,
+                                "offline_mode": self.offline_mode,
+                                "parent_trace_id": current_trace.parent_trace_id,
+                                "parent_name": current_trace.parent_name,
+                            }
+                            # Save the completed trace
+                            trace_id, server_response = current_trace.save(
+                                overwrite=overwrite, final_save=True
+                            )
+
+                            # Store the complete trace data instead of just server response
+
+                            self.traces.append(complete_trace_data)
+
+                            # if self.background_span_service:
+                            #     self.background_span_service.flush()
+
+                            # Reset trace context (span context resets automatically)
+                            self.reset_current_trace(trace_token)
+                        except Exception as e:
+                            warnings.warn(f"Issue with async_wrapper: {e}")
+                            return
                 else:
                     with current_trace.span(span_name, span_type=span_type) as span:
                         inputs = combine_args_kwargs(func, args, kwargs)
@@ -2014,24 +2106,28 @@ class Tracer:
                         span.record_agent_name(agent_name)
 
                         # Capture state before execution
-                        self._conditionally_capture_and_record_state(
-
-
-
-
-
-
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=True
+                        )
+
+                        try:
+                            if use_deep_tracing:
+                                with _DeepTracer(self):
+                                    result = await func(*args, **kwargs)
+                            else:
                                 result = await func(*args, **kwargs)
-
-
-
-
-                        # Capture state after execution
-                        self._conditionally_capture_and_record_state(
-
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e
+
+                        # Capture state after execution
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=False
+                        )
+
                         span.record_output(result)
                         return result
-
+
             return async_wrapper
         else:
             # Non-async function implementation with deep tracing
@@ -2039,118 +2135,146 @@ class Tracer:
            def wrapper(*args, **kwargs):
                nonlocal original_span_name
                class_name = None
-                instance_name = None
                span_name = original_span_name
                agent_name = None
-                if args and hasattr(args[0], "__class__"):
+                if args and hasattr(args[0], "__class__"):
                    class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
+                    agent_name = get_instance_prefixed_name(
+                        args[0], class_name, self.class_identifiers
+                    )
                # Get current trace from context
                current_trace = self.get_current_trace()
 
                # If there's no current trace, create a root trace
                if not current_trace:
                    trace_id = str(uuid.uuid4())
-                    project = project_name if project_name is not None else self.project_name
-
+                    project = (
+                        project_name if project_name is not None else self.project_name
+                    )
+
                    # Create a new trace client to serve as the root
                    current_trace = TraceClient(
                        self,
                        trace_id,
-                        span_name,
+                        span_name,  # MODIFIED: Use span_name directly
                        project_name=project,
                        overwrite=overwrite,
                        rules=self.rules,
                        enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations
+                        enable_evaluations=self.enable_evaluations,
                    )
-
+
                    # Save empty trace and set trace context
                    # current_trace.save(empty_save=True, overwrite=overwrite)
                    trace_token = self.set_current_trace(current_trace)
-
+
                    try:
                        # Use span for the function execution within the root trace
                        # This sets the current_span_var
-                        with current_trace.span(span_name, span_type=span_type) as span:
+                        with current_trace.span(
+                            span_name, span_type=span_type
+                        ) as span:  # MODIFIED: Use span_name directly
                            # Record inputs
                            inputs = combine_args_kwargs(func, args, kwargs)
                            span.record_input(inputs)
                            if agent_name:
                                span.record_agent_name(agent_name)
                            # Capture state before execution
-                            self._conditionally_capture_and_record_state(
-
-
-
-
-
-
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=True
+                            )
+
+                            try:
+                                if use_deep_tracing:
+                                    with _DeepTracer(self):
+                                        result = func(*args, **kwargs)
+                                else:
                                    result = func(*args, **kwargs)
-
-
-
-
+                            except Exception as e:
+                                _capture_exception_for_trace(
+                                    current_trace, sys.exc_info()
+                                )
+                                raise e
+
                            # Capture state after execution
-                            self._conditionally_capture_and_record_state(
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=False
+                            )
 
-
                            # Record output
                            span.record_output(result)
                            return result
                    finally:
                        # Flush background spans before saving the trace
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        try:
+                            # Save the completed trace
+                            trace_id, server_response = current_trace.save(
+                                overwrite=overwrite, final_save=True
+                            )
+
+                            # Store the complete trace data instead of just server response
+                            complete_trace_data = {
+                                "trace_id": current_trace.trace_id,
+                                "name": current_trace.name,
+                                "created_at": datetime.utcfromtimestamp(
+                                    current_trace.start_time
+                                ).isoformat(),
+                                "duration": current_trace.get_duration(),
+                                "trace_spans": [
+                                    span.model_dump()
+                                    for span in current_trace.trace_spans
+                                ],
+                                "overwrite": overwrite,
+                                "offline_mode": self.offline_mode,
+                                "parent_trace_id": current_trace.parent_trace_id,
+                                "parent_name": current_trace.parent_name,
+                            }
+                            self.traces.append(complete_trace_data)
+                            # Reset trace context (span context resets automatically)
+                            self.reset_current_trace(trace_token)
+                        except Exception as e:
+                            warnings.warn(f"Issue with save: {e}")
+                            return
                else:
                    with current_trace.span(span_name, span_type=span_type) as span:
-
                        inputs = combine_args_kwargs(func, args, kwargs)
                        span.record_input(inputs)
                        if agent_name:
                            span.record_agent_name(agent_name)
 
                        # Capture state before execution
-                        self._conditionally_capture_and_record_state(
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=True
+                        )
 
-
-
-
-
-
+                        try:
+                            if use_deep_tracing:
+                                with _DeepTracer(self):
+                                    result = func(*args, **kwargs)
+                            else:
                                result = func(*args, **kwargs)
-
-
-
-
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e
+
                        # Capture state after execution
-                        self._conditionally_capture_and_record_state(
-
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=False
+                        )
+
                        span.record_output(result)
                        return result
-
+
            return wrapper
-
-    def observe_tools(
-
+
+    def observe_tools(
+        self,
+        cls=None,
+        *,
+        exclude_methods: Optional[List[str]] = None,
+        include_private: bool = False,
+        warn_on_double_decoration: bool = True,
+    ):
        """
        Automatically adds @observe(span_type="tool") to all methods in a class.
 
@@ -2162,28 +2286,32 @@ class Tracer:
         """
 
         if exclude_methods is None:
-            exclude_methods = [
-
+            exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]
+
         def decorate_class(cls):
             if not self.enable_monitoring:
                 return cls
-
+
             decorated = []
             skipped = []
-
+
             for name in dir(cls):
                 method = getattr(cls, name)
-
-                if (
-
-
-
+
+                if (
+                    not callable(method)
+                    or name in exclude_methods
+                    or (name.startswith("_") and not include_private)
+                    or not hasattr(cls, name)
+                ):
                     continue
-
-                if hasattr(method, "_judgment_span_name"):
+
+                if hasattr(method, "_judgment_span_name"):
                     skipped.append(name)
                     if warn_on_double_decoration:
-                        print(
+                        print(
+                            f"Warning: {cls.__name__}.{name} already decorated, skipping"
+                        )
                     continue
 
                 try:
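A hedged usage sketch for the two decorators above; it assumes a configured environment, and the function and class are invented for illustration:

```python
from judgeval.common.tracer import Tracer

tracer = Tracer(project_name="demo")  # requires JUDGMENT_API_KEY / JUDGMENT_ORG_ID

@tracer.observe(span_type="tool", name="fetch_weather")
def fetch_weather(city: str) -> str:
    return f"sunny in {city}"

@tracer.observe_tools(exclude_methods=["reset"])
class Toolbox:
    def search(self, query: str):  # auto-decorated with span_type="tool"
        return f"results for {query}"

    def reset(self):  # listed in exclude_methods, left untouched
        return None
```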
@@ -2193,28 +2321,76 @@ class Tracer:
             except Exception as e:
                 if warn_on_double_decoration:
                     print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")
-
+
             return cls
-
+
         return decorate_class if cls is None else decorate_class(cls)
 
     def async_evaluate(self, *args, **kwargs):
-
-
+        try:
+            if not self.enable_monitoring or not self.enable_evaluations:
+                return
+
+            # --- Get trace_id passed explicitly (if any) ---
+            passed_trace_id = kwargs.pop(
+                "trace_id", None
+            )  # Get and remove trace_id from kwargs
+
+            current_trace = self.get_current_trace()
+
+            if current_trace:
+                # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
+                # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
+                if passed_trace_id:
+                    kwargs["trace_id"] = (
+                        passed_trace_id  # Re-add if needed by TraceClient.async_evaluate
+                    )
+                current_trace.async_evaluate(*args, **kwargs)
+            else:
+                warnings.warn(
+                    "No trace found (context var or fallback), skipping evaluation"
+                )  # Modified warning
+        except Exception as e:
+            warnings.warn(f"Issue with async_evaluate: {e}")
+
+        def update_metadata(self, metadata: dict):
+        """
+        Update metadata for the current trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+        """
+        current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.update_metadata(metadata)
+        else:
+            warnings.warn("No current trace found, cannot set metadata")
 
-
-
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for the current trace.
 
+        Args:
+            customer_id: The customer ID to set
+        """
         current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.set_customer_id(customer_id)
+        else:
+            warnings.warn("No current trace found, cannot set customer ID")
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for the current trace.
 
+        Args:
+            tags: List of tags to set
+        """
+        current_trace = self.get_current_trace()
         if current_trace:
-
-            # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
-            if passed_trace_id:
-                kwargs['trace_id'] = passed_trace_id # Re-add if needed by TraceClient.async_evaluate
-            current_trace.async_evaluate(*args, **kwargs)
+            current_trace.set_tags(tags)
         else:
-            warnings.warn("No trace found
+            warnings.warn("No current trace found, cannot set tags")
 
     def get_background_span_service(self) -> Optional[BackgroundSpanService]:
         """Get the background span service instance."""
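The trace-level setters added above all share the same guard: they warn instead of raising when no trace is active. A usage sketch, assuming the configured `tracer` from the earlier examples:

```python
with tracer.trace("support-session"):
    tracer.update_metadata({"channel": "email"})
    tracer.set_customer_id("cust-42")
    tracer.set_tags(["beta", "priority"])

# Outside a trace, each call just emits a warning instead of raising:
tracer.set_tags(["ignored"])  # -> "No current trace found, cannot set tags"
```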
@@ -2231,31 +2407,43 @@ class Tracer:
             self.background_span_service.shutdown()
             self.background_span_service = None
 
-
+
+def _get_current_trace(
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
+):
+    if trace_across_async_contexts:
+        return Tracer.current_trace
+    else:
+        return current_trace_var.get()
+
+
+def wrap(
+    client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
+) -> Any:
     """
     Wraps an API client to add tracing capabilities.
     Supports OpenAI, Together, Anthropic, and Google GenAI clients.
     Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
-    span_name, original_create, original_responses_create, original_stream, original_beta_parse = _get_client_config(client)
+    (
+        span_name,
+        original_create,
+        original_responses_create,
+        original_stream,
+        original_beta_parse,
+    ) = _get_client_config(client)
 
-    def _get_current_trace():
-        if trace_across_async_contexts:
-            return Tracer.current_trace
-        else:
-            return current_trace_var.get()
-
     def _record_input_and_check_streaming(span, kwargs, is_responses=False):
         """Record input and check for streaming"""
         is_streaming = kwargs.get("stream", False)
 
-
+        # Record input based on whether this is a responses endpoint
        if is_responses:
            span.record_input(kwargs)
        else:
            input_data = _format_input_data(client, **kwargs)
            span.record_input(input_data)
-
+
        # Warn about token counting limitations with streaming
        if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
            if not kwargs.get("stream_options", {}).get("include_usage"):
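How `wrap` is meant to be used, per its docstring; the client construction and model name are illustrative, and any supported client works the same way:

```python
from openai import OpenAI
from judgeval.common.tracer import Tracer, wrap

tracer = Tracer(project_name="demo")
client = wrap(OpenAI())  # patches .create so LLM calls land in the current trace

with tracer.trace("qa"):
    client.chat.completions.create(
        model="gpt-4o-mini",  # assumed model name for illustration
        messages=[{"role": "user", "content": "hello"}],
    )
```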
@@ -2263,88 +2451,101 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
                     "OpenAI streaming calls don't include token counts by default. "
                     "To enable token counting with streams, set stream_options={'include_usage': True} "
                     "in your API call arguments.",
-                    UserWarning
+                    UserWarning,
                 )
-
+
         return is_streaming
-
+
     def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
         """Format and record the output in the span"""
         if is_streaming:
             output_entry = span.record_output("<pending stream>")
             wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
-            return wrapper_func(
+            return wrapper_func(
+                response, client, output_entry, trace_across_async_contexts
+            )
         else:
-            format_func =
+            format_func = (
+                _format_response_output_data if is_responses else _format_output_data
+            )
             output, usage = format_func(client, response)
             span.record_output(output)
             span.record_usage(usage)
-
+
         # Queue the completed LLM span now that it has all data (input, output, usage)
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if current_trace and current_trace.background_span_service:
             # Get the current span from the trace client
             current_span_id = current_trace.get_current_span()
             if current_span_id and current_span_id in current_trace.span_id_to_span:
                 completed_span = current_trace.span_id_to_span[current_span_id]
-                current_trace.background_span_service.queue_span(
-
+                current_trace.background_span_service.queue_span(
+                    completed_span, span_state="completed"
+                )
+
         return response
-
+
     # --- Traced Async Functions ---
     async def traced_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = await original_create(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     async def traced_beta_parse_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_beta_parse(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = await original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
-
+
     # Async responses for OpenAI clients
     async def traced_response_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_responses_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
             try:
                 response_or_iterator = await original_responses_create(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, True
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     # Function replacing .stream() for async clients
     def traced_stream_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace or not original_stream:
             return original_stream(*args, **kwargs)
-
+
         original_manager = original_stream(*args, **kwargs)
         return _TracedAsyncStreamManagerWrapper(
             original_manager=original_manager,
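The `UserWarning` above fires because OpenAI streams omit token usage unless explicitly requested, leaving the traced span with no counts to record. A hedged example of the call shape the warning asks for (model name illustrative; `stream_options` is a real OpenAI parameter):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY in the environment
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    # Without this, the traced stream has no token counts to record.
    stream_options={"include_usage": True},
)
for chunk in stream:
    ...  # the final chunk carries a populated `usage` field
```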
@@ -2352,61 +2553,70 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
             span_name=span_name,
             trace_client=current_trace,
             stream_wrapper_func=_async_stream_wrapper,
-            input_kwargs=kwargs
+            input_kwargs=kwargs,
+            trace_across_async_contexts=trace_across_async_contexts,
         )
-
+
     # --- Traced Sync Functions ---
     def traced_create_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return original_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = original_create(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     def traced_beta_parse_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return original_beta_parse(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     def traced_response_create_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return original_responses_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
             try:
                 response_or_iterator = original_responses_create(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, True
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     # Function replacing sync .stream()
     def traced_stream_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace or not original_stream:
             return original_stream(*args, **kwargs)
-
+
         original_manager = original_stream(*args, **kwargs)
         return _TracedSyncStreamManagerWrapper(
             original_manager=original_manager,
@@ -2414,15 +2624,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
             span_name=span_name,
             trace_client=current_trace,
             stream_wrapper_func=_sync_stream_wrapper,
-            input_kwargs=kwargs
+            input_kwargs=kwargs,
+            trace_across_async_contexts=trace_across_async_contexts,
         )
-
+
     # --- Assign Traced Methods to Client Instance ---
     if isinstance(client, (AsyncOpenAI, AsyncTogether)):
         client.chat.completions.create = traced_create_async
         if hasattr(client, "responses") and hasattr(client.responses, "create"):
             client.responses.create = traced_response_create_async
-        if
+        if (
+            hasattr(client, "beta")
+            and hasattr(client.beta, "chat")
+            and hasattr(client.beta.chat, "completions")
+            and hasattr(client.beta.chat.completions, "parse")
+        ):
             client.beta.chat.completions.parse = traced_beta_parse_async
     elif isinstance(client, AsyncAnthropic):
         client.messages.create = traced_create_async
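Patching is now guarded by chained `hasattr` checks so that SDK builds lacking the `beta.chat.completions.parse` surface are skipped instead of raising `AttributeError`. A minimal sketch of the same defensive pattern in generic form (hypothetical helper, not part of judgeval):

```python
def patch_if_present(obj, path: str, replacement) -> bool:
    """Walk a dotted attribute path; patch the leaf only if every hop exists."""
    *parents, leaf = path.split(".")
    target = obj
    for name in parents:
        if not hasattr(target, name):
            return False  # some SDK versions lack this surface entirely
        target = getattr(target, name)
    if not hasattr(target, leaf):
        return False
    setattr(target, leaf, replacement)
    return True

# Usage sketch: patch_if_present(client, "beta.chat.completions.parse", traced_fn)
```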
@@ -2434,7 +2650,12 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
         client.chat.completions.create = traced_create_sync
         if hasattr(client, "responses") and hasattr(client.responses, "create"):
             client.responses.create = traced_response_create_sync
-        if
+        if (
+            hasattr(client, "beta")
+            and hasattr(client.beta, "chat")
+            and hasattr(client.beta.chat, "completions")
+            and hasattr(client.beta.chat.completions, "parse")
+        ):
             client.beta.chat.completions.parse = traced_beta_parse_sync
     elif isinstance(client, Anthropic):
         client.messages.create = traced_create_sync
@@ -2442,17 +2663,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
         client.messages.stream = traced_stream_sync
     elif isinstance(client, genai.Client):
         client.models.generate_content = traced_create_sync
-
+
     return client
 
+
 # Helper functions for client-specific operations
 
-
+
+def _get_client_config(
+    client: ApiClient,
+) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
     """Returns configuration tuple for the given API client.
-
+
     Args:
         client: An instance of OpenAI, Together, or Anthropic client
-
+
     Returns:
         tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
         - span_name: String identifier for tracing
@@ -2460,23 +2685,36 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
         - responses_method: Reference to the client's responses method (if applicable)
         - stream_method: Reference to the client's stream method (if applicable)
         - beta_parse_method: Reference to the client's beta parse method (if applicable)
-
+
     Raises:
         ValueError: If client type is not supported
     """
     if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return
+        return (
+            "OPENAI_API_CALL",
+            client.chat.completions.create,
+            client.responses.create,
+            None,
+            client.beta.chat.completions.parse,
+        )
     elif isinstance(client, (Together, AsyncTogether)):
         return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
    elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return
+        return (
+            "ANTHROPIC_API_CALL",
+            client.messages.create,
+            None,
+            client.messages.stream,
+            None,
+        )
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
         return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
     raise ValueError(f"Unsupported client type: {type(client)}")
 
+
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
     """Format input parameters based on client type.
-
+
     Extracts relevant parameters from kwargs based on the client type
     to ensure consistent tracing across different APIs.
     """
@@ -2489,25 +2727,23 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
         input_data["response_format"] = kwargs.get("response_format")
         return input_data
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return {
-            "model": kwargs.get("model"),
-            "contents": kwargs.get("contents")
-        }
+        return {"model": kwargs.get("model"), "contents": kwargs.get("contents")}
     # Anthropic requires additional max_tokens parameter
     return {
         "model": kwargs.get("model"),
         "messages": kwargs.get("messages"),
-        "max_tokens": kwargs.get("max_tokens")
+        "max_tokens": kwargs.get("max_tokens"),
     }
 
-
+
+
 def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
     """Format API response data based on client type.
-
+
     Normalizes different response formats into a consistent structure
     for tracing purposes.
     """
     message_content = None
-    prompt_tokens = 0
+    prompt_tokens = 0
     completion_tokens = 0
     model_name = None
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
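Taken together, `wrap` resolves the client-specific configuration via `_get_client_config`, swaps in the traced methods, and hands the same instance back, so call sites stay unchanged. A hedged end-to-end sketch (assuming judgeval's exports from this module; model name illustrative):

```python
from judgeval.common.tracer import Tracer, wrap
from openai import OpenAI

tracer = Tracer(project_name="my_project")  # assumes credentials in the environment
client = wrap(OpenAI())  # returns the same client with traced .create methods

@tracer.observe
def ask(question: str) -> str:
    # This call now records input, output, usage, and cost on the current trace.
    resp = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative
        messages=[{"role": "user", "content": question}],
    )
    return resp.choices[0].message.content
```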
@@ -2517,14 +2753,16 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
         message_content = response.output
     else:
         warnings.warn(f"Unsupported client type: {type(client)}")
-        return
-
-    prompt_cost, completion_cost = cost_per_token(
+        return None, None
+
+    prompt_cost, completion_cost = cost_per_token(
         model=model_name,
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
     )
-    total_cost_usd = (
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
     usage = TraceUsage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -2532,17 +2770,19 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
         prompt_tokens_cost_usd=prompt_cost,
         completion_tokens_cost_usd=completion_cost,
         total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
     )
     return message_content, usage
 
 
-def _format_output_data(
+def _format_output_data(
+    client: ApiClient, response: Any
+) -> tuple[Optional[str], Optional[TraceUsage]]:
     """Format API response data based on client type.
-
+
     Normalizes different response formats into a consistent structure
     for tracing purposes.
-
+
     Returns:
         dict containing:
         - content: The generated text
@@ -2557,7 +2797,10 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         model_name = response.model
         prompt_tokens = response.usage.prompt_tokens
         completion_tokens = response.usage.completion_tokens
-        if
+        if (
+            hasattr(response.choices[0].message, "parsed")
+            and response.choices[0].message.parsed
+        ):
             message_content = response.choices[0].message.parsed
         else:
             message_content = response.choices[0].message.content
@@ -2574,13 +2817,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
     else:
         warnings.warn(f"Unsupported client type: {type(client)}")
         return None, None
-
+
     prompt_cost, completion_cost = cost_per_token(
         model=model_name,
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
     )
-    total_cost_usd = (
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
     usage = TraceUsage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -2588,56 +2833,61 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         prompt_tokens_cost_usd=prompt_cost,
         completion_tokens_cost_usd=completion_cost,
         total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
     )
     return message_content, usage
 
+
 def combine_args_kwargs(func, args, kwargs):
     """
     Combine positional arguments and keyword arguments into a single dictionary.
-
+
     Args:
         func: The function being called
         args: Tuple of positional arguments
         kwargs: Dictionary of keyword arguments
-
+
     Returns:
         A dictionary combining both args and kwargs
     """
     try:
         import inspect
+
         sig = inspect.signature(func)
         param_names = list(sig.parameters.keys())
-
+
         args_dict = {}
         for i, arg in enumerate(args):
             if i < len(param_names):
                 args_dict[param_names[i]] = arg
             else:
                 args_dict[f"arg{i}"] = arg
-
+
         return {**args_dict, **kwargs}
-    except Exception
+    except Exception:
         # Fallback if signature inspection fails
         return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
 
+
 # NOTE: This builds once, can be tweaked if we are missing / capturing other unncessary modules
 # @link https://docs.python.org/3.13/library/sysconfig.html
 _TRACE_FILEPATH_BLOCKLIST = tuple(
     os.path.realpath(p) + os.sep
     for p in {
-        sysconfig.get_paths()[
-        sysconfig.get_paths().get(
+        sysconfig.get_paths()["stdlib"],
+        sysconfig.get_paths().get("platstdlib", ""),
         *site.getsitepackages(),
         site.getusersitepackages(),
         *(
-            [os.path.join(os.path.dirname(__file__),
-            if os.environ.get(
+            [os.path.join(os.path.dirname(__file__), "../../judgeval/")]
+            if os.environ.get("JUDGMENT_DEV")
             else []
         ),
-    }
+    }
+    if p
 )
 
+
 # Add the new TraceThreadPoolExecutor class
 class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     """
@@ -2649,6 +2899,7 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     allowing the Tracer to maintain correct parent-child relationships across
     thread boundaries.
     """
+
     def submit(self, fn, /, *args, **kwargs):
         """
         Submit a callable to be executed with the captured context.
@@ -2669,9 +2920,11 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     # Note: The `map` method would also need to be overridden for full context
     # propagation if users rely on it, but `submit` is the most common use case.
 
+
 # Helper functions for stream processing
 # ---------------------------------------
 
+
 def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
     """Extracts the text content from a stream chunk based on the client type."""
     try:
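`TraceThreadPoolExecutor` exists because plain worker threads start with an empty `contextvars` context, so the current-trace variable would vanish inside `submit`-ed callables. A minimal sketch of the capture-and-run pattern such an executor relies on (an independent reimplementation for illustration, not the class above verbatim):

```python
import concurrent.futures
import contextvars

request_id = contextvars.ContextVar("request_id", default=None)

class ContextPropagatingExecutor(concurrent.futures.ThreadPoolExecutor):
    def submit(self, fn, /, *args, **kwargs):
        ctx = contextvars.copy_context()  # snapshot the caller's context vars
        return super().submit(ctx.run, fn, *args, **kwargs)

def worker():
    return request_id.get()  # sees the submitting thread's value

request_id.set("req-42")
with ContextPropagatingExecutor() as pool:
    assert pool.submit(worker).result() == "req-42"
```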
@@ -2683,34 +2936,49 @@ def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
             return chunk.delta.text
         elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
             # Google streams Candidate objects
-            if
+            if (
+                chunk.candidates
+                and chunk.candidates[0].content
+                and chunk.candidates[0].content.parts
+            ):
                 return chunk.candidates[0].content.parts[0].text
     except (AttributeError, IndexError, KeyError):
         # Handle cases where chunk structure is unexpected or doesn't contain content
-        pass
+        pass  # Return None
     return None
 
-
+
+def _extract_usage_from_final_chunk(
+    client: ApiClient, chunk: Any
+) -> Optional[Dict[str, int]]:
     """Extracts usage data if present in the *final* chunk (client-specific)."""
     try:
         # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
         # This typically requires specific API versions or settings. Often usage is *not* streamed.
         if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
             # Check if usage is directly on the chunk (some models might do this)
-            if hasattr(chunk,
+            if hasattr(chunk, "usage") and chunk.usage:
                 prompt_tokens = chunk.usage.prompt_tokens
                 completion_tokens = chunk.usage.completion_tokens
             # Check if usage is nested within choices (less common for final chunk, but check)
-            elif
+            elif (
+                chunk.choices
+                and hasattr(chunk.choices[0], "usage")
+                and chunk.choices[0].usage
+            ):
                 prompt_tokens = chunk.choices[0].usage.prompt_tokens
                 completion_tokens = chunk.choices[0].usage.completion_tokens
-
+
             prompt_cost, completion_cost = cost_per_token(
-
-
-
-
-            total_cost_usd = (
+                model=chunk.model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
+            total_cost_usd = (
+                (prompt_cost + completion_cost)
+                if prompt_cost and completion_cost
+                else None
+            )
             return TraceUsage(
                 prompt_tokens=chunk.usage.prompt_tokens,
                 completion_tokens=chunk.usage.completion_tokens,
@@ -2718,9 +2986,9 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
                 prompt_tokens_cost_usd=prompt_cost,
                 completion_tokens_cost_usd=completion_cost,
                 total_cost_usd=total_cost_usd,
-                model_name=chunk.model
+                model_name=chunk.model,
             )
-
+        # Anthropic includes usage in the 'message_stop' event type
         elif isinstance(client, (Anthropic, AsyncAnthropic)):
             if chunk.type == "message_stop":
                 # Anthropic final usage is often attached to the *message* object, not the chunk directly
@@ -2729,18 +2997,18 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
             # This is a placeholder - Anthropic usage typically needs a separate call or context.
             pass
         elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-
-
-
-
-
-
-
-
+            # Google provides usage metadata on the full response object, not typically streamed per chunk.
+            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
+            if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+                return {
+                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
+                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
+                    "total_tokens": chunk.usage_metadata.total_token_count,
+                }
 
     except (AttributeError, IndexError, KeyError, TypeError):
         # Handle cases where usage data is missing or malformed
-
+        pass  # Return None
     return None
 
 
@@ -2748,7 +3016,8 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
 def _sync_stream_wrapper(
     original_stream: Iterator,
     client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> Generator[Any, None, None]:
     """Wraps a synchronous stream iterator to capture content and update the trace."""
     content_parts = []  # Use a list instead of string concatenation
@@ -2758,9 +3027,11 @@ def _sync_stream_wrapper(
         for chunk in original_stream:
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
-                content_parts.append(
-
-
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
+            last_chunk = chunk  # Keep track of the last chunk for potential usage data
+            yield chunk  # Pass the chunk to the caller
     finally:
         # Attempt to extract usage from the last chunk received
         if last_chunk:
@@ -2769,23 +3040,25 @@ def _sync_stream_wrapper(
         # Update the trace entry with the accumulated content and usage
         span.output = "".join(content_parts)
         span.usage = final_usage
-
+
         # Queue the completed LLM span now that streaming is done and all data is available
-
-
-
-
-
-
+
+        current_trace = _get_current_trace(trace_across_async_contexts)
+        if current_trace and current_trace.background_span_service:
+            current_trace.background_span_service.queue_span(
+                span, span_state="completed"
+            )
+
         # Note: We might need to adjust _serialize_output if this dict causes issues,
         # but Pydantic's model_dump should handle dicts.
 
+
 # --- Async Stream Wrapper ---
 async def _async_stream_wrapper(
     original_stream: AsyncIterator,
     client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> AsyncGenerator[Any, None]:
     # [Existing logic - unchanged]
     content_parts = []  # Use a list instead of string concatenation
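Both stream wrappers are generators that re-yield every chunk while accumulating text, then flush the span in `finally` so the trace is updated even if the consumer abandons the stream early. A stripped-down sketch of that control flow (no tracing types, hypothetical callback names):

```python
from typing import Any, Callable, Generator, Iterator, Optional

def stream_tap(
    stream: Iterator[Any],
    extract: Callable[[Any], Optional[str]],
    on_done: Callable[[str, Any], None],
) -> Generator[Any, None, None]:
    parts, last_chunk = [], None
    try:
        for chunk in stream:
            text = extract(chunk)
            if text:
                parts.append(text)  # accumulate without string concatenation
            last_chunk = chunk
            yield chunk  # pass the chunk through untouched
    finally:
        # Runs on normal exhaustion, generator .close(), or a downstream error.
        on_done("".join(parts), last_chunk)
```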
@@ -2794,56 +3067,70 @@ async def _async_stream_wrapper(
     anthropic_input_tokens = 0
     anthropic_output_tokens = 0
 
-    target_span_id = span.span_id
-
     try:
         model_name = ""
         async for chunk in original_stream:
             # Check for OpenAI's final usage chunk
-            if
+            if (
+                isinstance(client, (AsyncOpenAI, OpenAI))
+                and hasattr(chunk, "usage")
+                and chunk.usage is not None
+            ):
                 final_usage_data = {
                     "prompt_tokens": chunk.usage.prompt_tokens,
                     "completion_tokens": chunk.usage.completion_tokens,
-                    "total_tokens": chunk.usage.total_tokens
+                    "total_tokens": chunk.usage.total_tokens,
                 }
                 model_name = chunk.model
                 yield chunk
                 continue
 
-            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
+            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
+                chunk, "type"
+            ):
                 if chunk.type == "message_start":
-                    if
-
-
+                    if (
+                        hasattr(chunk, "message")
+                        and hasattr(chunk.message, "usage")
+                        and hasattr(chunk.message.usage, "input_tokens")
+                    ):
+                        anthropic_input_tokens = chunk.message.usage.input_tokens
+                        model_name = chunk.message.model
                 elif chunk.type == "message_delta":
-                    if hasattr(chunk,
+                    if hasattr(chunk, "usage") and hasattr(
+                        chunk.usage, "output_tokens"
+                    ):
                         anthropic_output_tokens = chunk.usage.output_tokens
 
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
-                content_parts.append(
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
                 last_content_chunk = chunk
 
             yield chunk
     finally:
         anthropic_final_usage = None
-        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
-
-
-
-
-
+        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
+            anthropic_input_tokens > 0 or anthropic_output_tokens > 0
+        ):
+            anthropic_final_usage = {
+                "prompt_tokens": anthropic_input_tokens,
+                "completion_tokens": anthropic_output_tokens,
+                "total_tokens": anthropic_input_tokens + anthropic_output_tokens,
+            }
 
         usage_info = None
         if final_usage_data:
-
+            usage_info = final_usage_data
         elif anthropic_final_usage:
-
+            usage_info = anthropic_final_usage
         elif last_content_chunk:
             usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
 
         if usage_info and not isinstance(usage_info, TraceUsage):
-            prompt_cost, completion_cost = cost_per_token(
+            prompt_cost, completion_cost = cost_per_token(
                 model=model_name,
                 prompt_tokens=usage_info["prompt_tokens"],
                 completion_tokens=usage_info["completion_tokens"],
@@ -2855,21 +3142,23 @@ async def _async_stream_wrapper(
                 prompt_tokens_cost_usd=prompt_cost,
                 completion_tokens_cost_usd=completion_cost,
                 total_cost_usd=prompt_cost + completion_cost,
-                model_name=model_name
+                model_name=model_name,
             )
-        if span and hasattr(span,
-            span.output =
+        if span and hasattr(span, "output"):
+            span.output = "".join(content_parts)
             span.usage = usage_info
-            start_ts = getattr(span,
+            start_ts = getattr(span, "created_at", time.time())
             span.duration = time.time() - start_ts
-
+
         # Queue the completed LLM span now that async streaming is done and all data is available
-
-
-
-
+        current_trace = _get_current_trace(trace_across_async_contexts)
+        if current_trace and current_trace.background_span_service:
+            current_trace.background_span_service.queue_span(
+                span, span_state="completed"
+            )
         # else: # Handle error case if necessary, but remove debug print
 
+
 def cost_per_token(*args, **kwargs):
     try:
         return _original_cost_per_token(*args, **kwargs)
|
|
2877
3166
|
warnings.warn(f"Error calculating cost per token: {e}")
|
2878
3167
|
return None, None
|
2879
3168
|
|
3169
|
+
|
2880
3170
|
class _BaseStreamManagerWrapper:
|
2881
|
-
def __init__(
|
3171
|
+
def __init__(
|
3172
|
+
self,
|
3173
|
+
original_manager,
|
3174
|
+
client,
|
3175
|
+
span_name,
|
3176
|
+
trace_client,
|
3177
|
+
stream_wrapper_func,
|
3178
|
+
input_kwargs,
|
3179
|
+
trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
|
3180
|
+
):
|
2882
3181
|
self._original_manager = original_manager
|
2883
3182
|
self._client = client
|
2884
3183
|
self._span_name = span_name
|
@@ -2886,13 +3185,19 @@ class _BaseStreamManagerWrapper:
|
|
2886
3185
|
self._stream_wrapper_func = stream_wrapper_func
|
2887
3186
|
self._input_kwargs = input_kwargs
|
2888
3187
|
self._parent_span_id_at_entry = None
|
3188
|
+
self._trace_across_async_contexts = trace_across_async_contexts
|
2889
3189
|
|
2890
3190
|
def _create_span(self):
|
2891
3191
|
start_time = time.time()
|
2892
3192
|
span_id = str(uuid.uuid4())
|
2893
3193
|
current_depth = 0
|
2894
|
-
if
|
2895
|
-
|
3194
|
+
if (
|
3195
|
+
self._parent_span_id_at_entry
|
3196
|
+
and self._parent_span_id_at_entry in self._trace_client._span_depths
|
3197
|
+
):
|
3198
|
+
current_depth = (
|
3199
|
+
self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
|
3200
|
+
)
|
2896
3201
|
self._trace_client._span_depths[span_id] = current_depth
|
2897
3202
|
span = TraceSpan(
|
2898
3203
|
function=self._span_name,
|
@@ -2902,7 +3207,7 @@ class _BaseStreamManagerWrapper:
|
|
2902
3207
|
message=self._span_name,
|
2903
3208
|
created_at=start_time,
|
2904
3209
|
span_type="llm",
|
2905
|
-
parent_span_id=self._parent_span_id_at_entry
|
3210
|
+
parent_span_id=self._parent_span_id_at_entry,
|
2906
3211
|
)
|
2907
3212
|
self._trace_client.add_span(span)
|
2908
3213
|
return span_id, span
|
@@ -2914,7 +3219,10 @@ class _BaseStreamManagerWrapper:
|
|
2914
3219
|
if span_id in self._trace_client._span_depths:
|
2915
3220
|
del self._trace_client._span_depths[span_id]
|
2916
3221
|
|
2917
|
-
|
3222
|
+
|
3223
|
+
class _TracedAsyncStreamManagerWrapper(
|
3224
|
+
_BaseStreamManagerWrapper, AbstractAsyncContextManager
|
3225
|
+
):
|
2918
3226
|
async def __aenter__(self):
|
2919
3227
|
self._parent_span_id_at_entry = self._trace_client.get_current_span()
|
2920
3228
|
if not self._trace_client:
|
@@ -2927,17 +3235,22 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
|
|
2927
3235
|
# Call the original __aenter__ and expect it to be an async generator
|
2928
3236
|
raw_iterator = await self._original_manager.__aenter__()
|
2929
3237
|
span.output = "<pending stream>"
|
2930
|
-
return self._stream_wrapper_func(
|
3238
|
+
return self._stream_wrapper_func(
|
3239
|
+
raw_iterator, self._client, span, self._trace_across_async_contexts
|
3240
|
+
)
|
2931
3241
|
|
2932
3242
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
2933
|
-
if hasattr(self,
|
3243
|
+
if hasattr(self, "_span_context_token"):
|
2934
3244
|
span_id = self._trace_client.get_current_span()
|
2935
3245
|
self._finalize_span(span_id)
|
2936
3246
|
self._trace_client.reset_current_span(self._span_context_token)
|
2937
|
-
delattr(self,
|
3247
|
+
delattr(self, "_span_context_token")
|
2938
3248
|
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
2939
3249
|
|
2940
|
-
|
3250
|
+
|
3251
|
+
class _TracedSyncStreamManagerWrapper(
|
3252
|
+
_BaseStreamManagerWrapper, AbstractContextManager
|
3253
|
+
):
|
2941
3254
|
def __enter__(self):
|
2942
3255
|
self._parent_span_id_at_entry = self._trace_client.get_current_span()
|
2943
3256
|
if not self._trace_client:
|
@@ -2949,16 +3262,19 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
|
|
2949
3262
|
|
2950
3263
|
raw_iterator = self._original_manager.__enter__()
|
2951
3264
|
span.output = "<pending stream>"
|
2952
|
-
return self._stream_wrapper_func(
|
3265
|
+
return self._stream_wrapper_func(
|
3266
|
+
raw_iterator, self._client, span, self._trace_across_async_contexts
|
3267
|
+
)
|
2953
3268
|
|
2954
3269
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
2955
|
-
if hasattr(self,
|
3270
|
+
if hasattr(self, "_span_context_token"):
|
2956
3271
|
span_id = self._trace_client.get_current_span()
|
2957
3272
|
self._finalize_span(span_id)
|
2958
3273
|
self._trace_client.reset_current_span(self._span_context_token)
|
2959
|
-
delattr(self,
|
3274
|
+
delattr(self, "_span_context_token")
|
2960
3275
|
return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
|
2961
3276
|
|
3277
|
+
|
2962
3278
|
# --- Helper function for instance-prefixed qual_name ---
|
2963
3279
|
def get_instance_prefixed_name(instance, class_name, class_identifiers):
|
2964
3280
|
"""
|
@@ -2967,11 +3283,13 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
|
|
2967
3283
|
"""
|
2968
3284
|
if class_name in class_identifiers:
|
2969
3285
|
class_config = class_identifiers[class_name]
|
2970
|
-
attr = class_config[
|
2971
|
-
|
3286
|
+
attr = class_config["identifier"]
|
3287
|
+
|
2972
3288
|
if hasattr(instance, attr):
|
2973
3289
|
instance_name = getattr(instance, attr)
|
2974
3290
|
return instance_name
|
2975
3291
|
else:
|
2976
|
-
raise Exception(
|
2977
|
-
|
3292
|
+
raise Exception(
|
3293
|
+
f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
|
3294
|
+
)
|
3295
|
+
return None
|