judgeval 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +5 -4
- judgeval/clients.py +6 -6
- judgeval/common/__init__.py +7 -2
- judgeval/common/exceptions.py +2 -3
- judgeval/common/logger.py +74 -49
- judgeval/common/s3_storage.py +30 -23
- judgeval/common/tracer.py +1273 -939
- judgeval/common/utils.py +416 -244
- judgeval/constants.py +73 -61
- judgeval/data/__init__.py +1 -1
- judgeval/data/custom_example.py +3 -2
- judgeval/data/datasets/dataset.py +80 -54
- judgeval/data/datasets/eval_dataset_client.py +131 -181
- judgeval/data/example.py +67 -43
- judgeval/data/result.py +11 -9
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +25 -16
- judgeval/data/trace.py +57 -29
- judgeval/data/trace_run.py +5 -11
- judgeval/evaluation_run.py +22 -82
- judgeval/integrations/langgraph.py +546 -184
- judgeval/judges/base_judge.py +1 -2
- judgeval/judges/litellm_judge.py +33 -11
- judgeval/judges/mixture_of_judges.py +128 -78
- judgeval/judges/together_judge.py +22 -9
- judgeval/judges/utils.py +14 -5
- judgeval/judgment_client.py +259 -271
- judgeval/rules.py +169 -142
- judgeval/run_evaluation.py +462 -305
- judgeval/scorers/api_scorer.py +20 -11
- judgeval/scorers/exceptions.py +1 -0
- judgeval/scorers/judgeval_scorer.py +77 -58
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
- judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
- judgeval/scorers/prompt_scorer.py +48 -37
- judgeval/scorers/score.py +86 -53
- judgeval/scorers/utils.py +11 -7
- judgeval/tracer/__init__.py +1 -1
- judgeval/utils/alerts.py +23 -12
- judgeval/utils/{data_utils.py → file_utils.py} +5 -9
- judgeval/utils/requests.py +29 -0
- judgeval/version_check.py +5 -2
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/METADATA +79 -135
- judgeval-0.0.46.dist-info/RECORD +69 -0
- judgeval-0.0.44.dist-info/RECORD +0 -68
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/WHEEL +0 -0
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""
|
2
2
|
Tracing system for judgeval that allows for function tracing using decorators.
|
3
3
|
"""
|
4
|
+
|
4
5
|
# Standard library imports
|
5
6
|
import asyncio
|
6
7
|
import functools
|
@@ -16,9 +17,13 @@ import warnings
|
|
16
17
|
import contextvars
|
17
18
|
import sys
|
18
19
|
import json
|
19
|
-
from contextlib import
|
20
|
-
|
21
|
-
|
20
|
+
from contextlib import (
|
21
|
+
contextmanager,
|
22
|
+
AbstractAsyncContextManager,
|
23
|
+
AbstractContextManager,
|
24
|
+
) # Import context manager bases
|
25
|
+
from dataclasses import dataclass
|
26
|
+
from datetime import datetime
|
22
27
|
from http import HTTPStatus
|
23
28
|
from typing import (
|
24
29
|
Any,
|
@@ -37,9 +42,9 @@ from rich import print as rprint
|
|
37
42
|
import types
|
38
43
|
|
39
44
|
# Third-party imports
|
40
|
-
import
|
45
|
+
from requests import RequestException
|
46
|
+
from judgeval.utils.requests import requests
|
41
47
|
from litellm import cost_per_token as _original_cost_per_token
|
42
|
-
from rich import print as rprint
|
43
48
|
from openai import OpenAI, AsyncOpenAI
|
44
49
|
from together import Together, AsyncTogether
|
45
50
|
from anthropic import Anthropic, AsyncAnthropic
|
@@ -50,12 +55,7 @@ from judgeval.constants import (
|
|
50
55
|
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
51
56
|
JUDGMENT_TRACES_SAVE_API_URL,
|
52
57
|
JUDGMENT_TRACES_UPSERT_API_URL,
|
53
|
-
JUDGMENT_TRACES_USAGE_CHECK_API_URL,
|
54
|
-
JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
|
55
58
|
JUDGMENT_TRACES_FETCH_API_URL,
|
56
|
-
RABBITMQ_HOST,
|
57
|
-
RABBITMQ_PORT,
|
58
|
-
RABBITMQ_QUEUE,
|
59
59
|
JUDGMENT_TRACES_DELETE_API_URL,
|
60
60
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
61
61
|
JUDGMENT_TRACES_SPANS_BATCH_API_URL,
|
@@ -70,32 +70,37 @@ from judgeval.common.exceptions import JudgmentAPIError
|
|
70
70
|
|
71
71
|
# Standard library imports needed for the new class
|
72
72
|
import concurrent.futures
|
73
|
-
from collections.abc import Iterator, AsyncIterator
|
73
|
+
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
74
74
|
import queue
|
75
75
|
import atexit
|
76
76
|
|
77
77
|
# Define context variables for tracking the current trace and the current span within a trace
|
78
|
-
current_trace_var = contextvars.ContextVar[Optional[
|
79
|
-
|
78
|
+
current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
|
79
|
+
"current_trace", default=None
|
80
|
+
)
|
81
|
+
current_span_var = contextvars.ContextVar[Optional[str]](
|
82
|
+
"current_span", default=None
|
83
|
+
) # ContextVar for the active span id
|
80
84
|
|
81
85
|
# Define type aliases for better code readability and maintainability
|
82
|
-
ApiClient: TypeAlias = Union[
|
83
|
-
|
86
|
+
ApiClient: TypeAlias = Union[
|
87
|
+
OpenAI,
|
88
|
+
Together,
|
89
|
+
Anthropic,
|
90
|
+
AsyncOpenAI,
|
91
|
+
AsyncAnthropic,
|
92
|
+
AsyncTogether,
|
93
|
+
genai.Client,
|
94
|
+
genai.client.AsyncClient,
|
95
|
+
] # Supported API clients
|
96
|
+
SpanType = Literal["span", "tool", "llm", "evaluation", "chain"]
|
84
97
|
|
85
|
-
# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
|
86
|
-
@dataclass
|
87
|
-
class EvaluationConfig:
|
88
|
-
"""Configuration for triggering an evaluation from the handler."""
|
89
|
-
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
|
90
|
-
example: Example
|
91
|
-
model: Optional[str] = None
|
92
|
-
log_results: Optional[bool] = True
|
93
|
-
# --- End Evaluation Config Dataclass ---
|
94
98
|
|
95
99
|
# Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
|
96
100
|
@dataclass
|
97
101
|
class TraceAnnotation:
|
98
102
|
"""Represents a single annotation for a trace span."""
|
103
|
+
|
99
104
|
span_id: str
|
100
105
|
text: str
|
101
106
|
label: str
|
@@ -105,24 +110,27 @@ class TraceAnnotation:
|
|
105
110
|
"""Convert the annotation to a dictionary format for storage/transmission."""
|
106
111
|
return {
|
107
112
|
"span_id": self.span_id,
|
108
|
-
"annotation": {
|
109
|
-
"text": self.text,
|
110
|
-
"label": self.label,
|
111
|
-
"score": self.score
|
112
|
-
}
|
113
|
+
"annotation": {"text": self.text, "label": self.label, "score": self.score},
|
113
114
|
}
|
114
|
-
|
115
|
+
|
116
|
+
|
115
117
|
class TraceManagerClient:
|
116
118
|
"""
|
117
119
|
Client for handling trace endpoints with the Judgment API
|
118
|
-
|
120
|
+
|
119
121
|
|
120
122
|
Operations include:
|
121
123
|
- Fetching a trace by id
|
122
124
|
- Saving a trace
|
123
125
|
- Deleting a trace
|
124
126
|
"""
|
125
|
-
|
127
|
+
|
128
|
+
def __init__(
|
129
|
+
self,
|
130
|
+
judgment_api_key: str,
|
131
|
+
organization_id: str,
|
132
|
+
tracer: Optional["Tracer"] = None,
|
133
|
+
):
|
126
134
|
self.judgment_api_key = judgment_api_key
|
127
135
|
self.organization_id = organization_id
|
128
136
|
self.tracer = tracer
|
@@ -139,17 +147,19 @@ class TraceManagerClient:
|
|
139
147
|
headers={
|
140
148
|
"Content-Type": "application/json",
|
141
149
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
142
|
-
"X-Organization-Id": self.organization_id
|
150
|
+
"X-Organization-Id": self.organization_id,
|
143
151
|
},
|
144
|
-
verify=True
|
152
|
+
verify=True,
|
145
153
|
)
|
146
154
|
|
147
155
|
if response.status_code != HTTPStatus.OK:
|
148
156
|
raise ValueError(f"Failed to fetch traces: {response.text}")
|
149
|
-
|
157
|
+
|
150
158
|
return response.json()
|
151
159
|
|
152
|
-
def save_trace(
|
160
|
+
def save_trace(
|
161
|
+
self, trace_data: dict, offline_mode: bool = False, final_save: bool = True
|
162
|
+
):
|
153
163
|
"""
|
154
164
|
Saves a trace to the Judgment Supabase and optionally to S3 if configured.
|
155
165
|
|
@@ -158,12 +168,12 @@ class TraceManagerClient:
|
|
158
168
|
offline_mode: Whether running in offline mode
|
159
169
|
final_save: Whether this is the final save (controls S3 saving)
|
160
170
|
NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
|
161
|
-
|
171
|
+
|
162
172
|
Returns:
|
163
173
|
dict: Server response containing UI URL and other metadata
|
164
174
|
"""
|
165
175
|
# Save to Judgment API
|
166
|
-
|
176
|
+
|
167
177
|
def fallback_encoder(obj):
|
168
178
|
"""
|
169
179
|
Custom JSON encoder fallback.
|
@@ -180,7 +190,7 @@ class TraceManagerClient:
|
|
180
190
|
except Exception as e:
|
181
191
|
# If both fail, you might return a placeholder or re-raise
|
182
192
|
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
183
|
-
|
193
|
+
|
184
194
|
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
185
195
|
response = requests.post(
|
186
196
|
JUDGMENT_TRACES_SAVE_API_URL,
|
@@ -188,71 +198,46 @@ class TraceManagerClient:
|
|
188
198
|
headers={
|
189
199
|
"Content-Type": "application/json",
|
190
200
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
191
|
-
"X-Organization-Id": self.organization_id
|
201
|
+
"X-Organization-Id": self.organization_id,
|
192
202
|
},
|
193
|
-
verify=True
|
203
|
+
verify=True,
|
194
204
|
)
|
195
|
-
|
205
|
+
|
196
206
|
if response.status_code == HTTPStatus.BAD_REQUEST:
|
197
|
-
raise ValueError(
|
207
|
+
raise ValueError(
|
208
|
+
f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}"
|
209
|
+
)
|
198
210
|
elif response.status_code != HTTPStatus.OK:
|
199
211
|
raise ValueError(f"Failed to save trace data: {response.text}")
|
200
|
-
|
212
|
+
|
201
213
|
# Parse server response
|
202
214
|
server_response = response.json()
|
203
|
-
|
215
|
+
|
204
216
|
# If S3 storage is enabled, save to S3 only on final save
|
205
217
|
if self.tracer and self.tracer.use_s3 and final_save:
|
206
218
|
try:
|
207
219
|
s3_key = self.tracer.s3_storage.save_trace(
|
208
220
|
trace_data=trace_data,
|
209
221
|
trace_id=trace_data["trace_id"],
|
210
|
-
project_name=trace_data["project_name"]
|
222
|
+
project_name=trace_data["project_name"],
|
211
223
|
)
|
212
224
|
print(f"Trace also saved to S3 at key: {s3_key}")
|
213
225
|
except Exception as e:
|
214
226
|
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
215
|
-
|
227
|
+
|
216
228
|
if not offline_mode and "ui_results_url" in server_response:
|
217
229
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
218
230
|
rprint(pretty_str)
|
219
|
-
|
220
|
-
return server_response
|
221
231
|
|
222
|
-
|
223
|
-
"""
|
224
|
-
Check if the organization can use the requested number of traces without exceeding limits.
|
225
|
-
|
226
|
-
Args:
|
227
|
-
count: Number of traces to check for (default: 1)
|
228
|
-
|
229
|
-
Returns:
|
230
|
-
dict: Server response with rate limit status and usage info
|
231
|
-
|
232
|
-
Raises:
|
233
|
-
ValueError: If rate limits would be exceeded or other errors occur
|
234
|
-
"""
|
235
|
-
response = requests.post(
|
236
|
-
JUDGMENT_TRACES_USAGE_CHECK_API_URL,
|
237
|
-
json={"count": count},
|
238
|
-
headers={
|
239
|
-
"Content-Type": "application/json",
|
240
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
241
|
-
"X-Organization-Id": self.organization_id
|
242
|
-
},
|
243
|
-
verify=True
|
244
|
-
)
|
245
|
-
|
246
|
-
if response.status_code == HTTPStatus.FORBIDDEN:
|
247
|
-
# Rate limits exceeded
|
248
|
-
error_data = response.json()
|
249
|
-
raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
|
250
|
-
elif response.status_code != HTTPStatus.OK:
|
251
|
-
raise ValueError(f"Failed to check usage limits: {response.text}")
|
252
|
-
|
253
|
-
return response.json()
|
232
|
+
return server_response
|
254
233
|
|
255
|
-
def upsert_trace(
|
234
|
+
def upsert_trace(
|
235
|
+
self,
|
236
|
+
trace_data: dict,
|
237
|
+
offline_mode: bool = False,
|
238
|
+
show_link: bool = True,
|
239
|
+
final_save: bool = True,
|
240
|
+
):
|
256
241
|
"""
|
257
242
|
Upserts a trace to the Judgment API (always overwrites if exists).
|
258
243
|
|
@@ -261,10 +246,11 @@ class TraceManagerClient:
|
|
261
246
|
offline_mode: Whether running in offline mode
|
262
247
|
show_link: Whether to show the UI link (for live tracing)
|
263
248
|
final_save: Whether this is the final save (controls S3 saving)
|
264
|
-
|
249
|
+
|
265
250
|
Returns:
|
266
251
|
dict: Server response containing UI URL and other metadata
|
267
252
|
"""
|
253
|
+
|
268
254
|
def fallback_encoder(obj):
|
269
255
|
"""
|
270
256
|
Custom JSON encoder fallback.
|
@@ -277,7 +263,7 @@ class TraceManagerClient:
|
|
277
263
|
return str(obj)
|
278
264
|
except Exception as e:
|
279
265
|
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
280
|
-
|
266
|
+
|
281
267
|
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
282
268
|
|
283
269
|
response = requests.post(
|
@@ -286,63 +272,34 @@ class TraceManagerClient:
|
|
286
272
|
headers={
|
287
273
|
"Content-Type": "application/json",
|
288
274
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
289
|
-
"X-Organization-Id": self.organization_id
|
275
|
+
"X-Organization-Id": self.organization_id,
|
290
276
|
},
|
291
|
-
verify=True
|
277
|
+
verify=True,
|
292
278
|
)
|
293
|
-
|
279
|
+
|
294
280
|
if response.status_code != HTTPStatus.OK:
|
295
281
|
raise ValueError(f"Failed to upsert trace data: {response.text}")
|
296
|
-
|
282
|
+
|
297
283
|
# Parse server response
|
298
284
|
server_response = response.json()
|
299
|
-
|
285
|
+
|
300
286
|
# If S3 storage is enabled, save to S3 only on final save
|
301
287
|
if self.tracer and self.tracer.use_s3 and final_save:
|
302
288
|
try:
|
303
289
|
s3_key = self.tracer.s3_storage.save_trace(
|
304
290
|
trace_data=trace_data,
|
305
291
|
trace_id=trace_data["trace_id"],
|
306
|
-
project_name=trace_data["project_name"]
|
292
|
+
project_name=trace_data["project_name"],
|
307
293
|
)
|
308
294
|
print(f"Trace also saved to S3 at key: {s3_key}")
|
309
295
|
except Exception as e:
|
310
296
|
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
311
|
-
|
297
|
+
|
312
298
|
if not offline_mode and show_link and "ui_results_url" in server_response:
|
313
299
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
314
300
|
rprint(pretty_str)
|
315
|
-
|
316
|
-
return server_response
|
317
301
|
|
318
|
-
|
319
|
-
"""
|
320
|
-
Update trace usage counters after successfully saving traces.
|
321
|
-
|
322
|
-
Args:
|
323
|
-
count: Number of traces to count (default: 1)
|
324
|
-
|
325
|
-
Returns:
|
326
|
-
dict: Server response with updated usage information
|
327
|
-
|
328
|
-
Raises:
|
329
|
-
ValueError: If the update fails
|
330
|
-
"""
|
331
|
-
response = requests.post(
|
332
|
-
JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
|
333
|
-
json={"count": count},
|
334
|
-
headers={
|
335
|
-
"Content-Type": "application/json",
|
336
|
-
"Authorization": f"Bearer {self.judgment_api_key}",
|
337
|
-
"X-Organization-Id": self.organization_id
|
338
|
-
},
|
339
|
-
verify=True
|
340
|
-
)
|
341
|
-
|
342
|
-
if response.status_code != HTTPStatus.OK:
|
343
|
-
raise ValueError(f"Failed to update usage counters: {response.text}")
|
344
|
-
|
345
|
-
return response.json()
|
302
|
+
return server_response
|
346
303
|
|
347
304
|
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
348
305
|
def save_annotation(self, annotation: TraceAnnotation):
|
@@ -351,24 +308,24 @@ class TraceManagerClient:
|
|
351
308
|
"annotation": {
|
352
309
|
"text": annotation.text,
|
353
310
|
"label": annotation.label,
|
354
|
-
"score": annotation.score
|
355
|
-
}
|
356
|
-
}
|
311
|
+
"score": annotation.score,
|
312
|
+
},
|
313
|
+
}
|
357
314
|
|
358
315
|
response = requests.post(
|
359
316
|
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
360
317
|
json=json_data,
|
361
318
|
headers={
|
362
|
-
|
363
|
-
|
364
|
-
|
319
|
+
"Content-Type": "application/json",
|
320
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
321
|
+
"X-Organization-Id": self.organization_id,
|
365
322
|
},
|
366
|
-
verify=True
|
323
|
+
verify=True,
|
367
324
|
)
|
368
|
-
|
325
|
+
|
369
326
|
if response.status_code != HTTPStatus.OK:
|
370
327
|
raise ValueError(f"Failed to save annotation: {response.text}")
|
371
|
-
|
328
|
+
|
372
329
|
return response.json()
|
373
330
|
|
374
331
|
def delete_trace(self, trace_id: str):
|
@@ -383,15 +340,15 @@ class TraceManagerClient:
|
|
383
340
|
headers={
|
384
341
|
"Content-Type": "application/json",
|
385
342
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
386
|
-
"X-Organization-Id": self.organization_id
|
387
|
-
}
|
343
|
+
"X-Organization-Id": self.organization_id,
|
344
|
+
},
|
388
345
|
)
|
389
346
|
|
390
347
|
if response.status_code != HTTPStatus.OK:
|
391
348
|
raise ValueError(f"Failed to delete trace: {response.text}")
|
392
|
-
|
349
|
+
|
393
350
|
return response.json()
|
394
|
-
|
351
|
+
|
395
352
|
def delete_traces(self, trace_ids: List[str]):
|
396
353
|
"""
|
397
354
|
Delete a batch of traces from the database.
|
@@ -404,15 +361,15 @@ class TraceManagerClient:
|
|
404
361
|
headers={
|
405
362
|
"Content-Type": "application/json",
|
406
363
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
407
|
-
"X-Organization-Id": self.organization_id
|
408
|
-
}
|
364
|
+
"X-Organization-Id": self.organization_id,
|
365
|
+
},
|
409
366
|
)
|
410
367
|
|
411
368
|
if response.status_code != HTTPStatus.OK:
|
412
369
|
raise ValueError(f"Failed to delete trace: {response.text}")
|
413
|
-
|
370
|
+
|
414
371
|
return response.json()
|
415
|
-
|
372
|
+
|
416
373
|
def delete_project(self, project_name: str):
|
417
374
|
"""
|
418
375
|
Deletes a project from the server. Which also deletes all evaluations and traces associated with the project.
|
@@ -425,31 +382,31 @@ class TraceManagerClient:
|
|
425
382
|
headers={
|
426
383
|
"Content-Type": "application/json",
|
427
384
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
428
|
-
"X-Organization-Id": self.organization_id
|
429
|
-
}
|
385
|
+
"X-Organization-Id": self.organization_id,
|
386
|
+
},
|
430
387
|
)
|
431
388
|
|
432
389
|
if response.status_code != HTTPStatus.OK:
|
433
390
|
raise ValueError(f"Failed to delete traces: {response.text}")
|
434
|
-
|
391
|
+
|
435
392
|
return response.json()
|
436
393
|
|
437
394
|
|
438
395
|
class TraceClient:
|
439
396
|
"""Client for managing a single trace context"""
|
440
|
-
|
397
|
+
|
441
398
|
def __init__(
|
442
399
|
self,
|
443
|
-
tracer:
|
400
|
+
tracer: "Tracer",
|
444
401
|
trace_id: Optional[str] = None,
|
445
402
|
name: str = "default",
|
446
|
-
project_name: str = None,
|
403
|
+
project_name: str | None = None,
|
447
404
|
overwrite: bool = False,
|
448
405
|
rules: Optional[List[Rule]] = None,
|
449
406
|
enable_monitoring: bool = True,
|
450
407
|
enable_evaluations: bool = True,
|
451
408
|
parent_trace_id: Optional[str] = None,
|
452
|
-
parent_name: Optional[str] = None
|
409
|
+
parent_name: Optional[str] = None,
|
453
410
|
):
|
454
411
|
self.name = name
|
455
412
|
self.trace_id = trace_id or str(uuid.uuid4())
|
@@ -461,39 +418,48 @@ class TraceClient:
|
|
461
418
|
self.enable_evaluations = enable_evaluations
|
462
419
|
self.parent_trace_id = parent_trace_id
|
463
420
|
self.parent_name = parent_name
|
421
|
+
self.customer_id: Optional[str] = None # Added customer_id attribute
|
422
|
+
self.tags: List[Union[str, set, tuple]] = [] # Added tags attribute
|
423
|
+
self.metadata: Dict[str, Any] = {}
|
424
|
+
self.has_notification: Optional[bool] = False # Initialize has_notification
|
464
425
|
self.trace_spans: List[TraceSpan] = []
|
465
426
|
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
466
427
|
self.evaluation_runs: List[EvaluationRun] = []
|
467
428
|
self.annotations: List[TraceAnnotation] = []
|
468
|
-
self.start_time =
|
469
|
-
|
470
|
-
|
471
|
-
self.
|
472
|
-
|
473
|
-
|
474
|
-
|
429
|
+
self.start_time: Optional[float] = (
|
430
|
+
None # Will be set after first successful save
|
431
|
+
)
|
432
|
+
self.trace_manager_client = TraceManagerClient(
|
433
|
+
tracer.api_key, tracer.organization_id, tracer
|
434
|
+
)
|
435
|
+
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
436
|
+
|
475
437
|
# Get background span service from tracer
|
476
|
-
self.background_span_service =
|
438
|
+
self.background_span_service = (
|
439
|
+
tracer.get_background_span_service() if tracer else None
|
440
|
+
)
|
477
441
|
|
478
442
|
def get_current_span(self):
|
479
443
|
"""Get the current span from the context var"""
|
480
444
|
return self.tracer.get_current_span()
|
481
|
-
|
445
|
+
|
482
446
|
def set_current_span(self, span: Any):
|
483
447
|
"""Set the current span from the context var"""
|
484
448
|
return self.tracer.set_current_span(span)
|
485
|
-
|
449
|
+
|
486
450
|
def reset_current_span(self, token: Any):
|
487
451
|
"""Reset the current span from the context var"""
|
488
452
|
self.tracer.reset_current_span(token)
|
489
|
-
|
453
|
+
|
490
454
|
@contextmanager
|
491
455
|
def span(self, name: str, span_type: SpanType = "span"):
|
492
456
|
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
493
457
|
is_first_span = len(self.trace_spans) == 0
|
494
458
|
if is_first_span:
|
495
459
|
try:
|
496
|
-
trace_id, server_response = self.
|
460
|
+
trace_id, server_response = self.save(
|
461
|
+
overwrite=self.overwrite, final_save=False
|
462
|
+
)
|
497
463
|
# Set start_time after first successful save
|
498
464
|
if self.start_time is None:
|
499
465
|
self.start_time = time.time()
|
@@ -501,19 +467,23 @@ class TraceClient:
|
|
501
467
|
except Exception as e:
|
502
468
|
warnings.warn(f"Failed to save initial trace for live tracking: {e}")
|
503
469
|
start_time = time.time()
|
504
|
-
|
470
|
+
|
505
471
|
# Generate a unique ID for *this specific span invocation*
|
506
472
|
span_id = str(uuid.uuid4())
|
507
|
-
|
508
|
-
parent_span_id =
|
509
|
-
|
510
|
-
|
473
|
+
|
474
|
+
parent_span_id = (
|
475
|
+
self.get_current_span()
|
476
|
+
) # Get ID of the parent span from context var
|
477
|
+
token = self.set_current_span(
|
478
|
+
span_id
|
479
|
+
) # Set *this* span's ID as the current one
|
480
|
+
|
511
481
|
current_depth = 0
|
512
482
|
if parent_span_id and parent_span_id in self._span_depths:
|
513
483
|
current_depth = self._span_depths[parent_span_id] + 1
|
514
|
-
|
515
|
-
self._span_depths[span_id] = current_depth
|
516
|
-
|
484
|
+
|
485
|
+
self._span_depths[span_id] = current_depth # Store depth by span_id
|
486
|
+
|
517
487
|
span = TraceSpan(
|
518
488
|
span_id=span_id,
|
519
489
|
trace_id=self.trace_id,
|
@@ -525,23 +495,21 @@ class TraceClient:
|
|
525
495
|
function=name,
|
526
496
|
)
|
527
497
|
self.add_span(span)
|
528
|
-
|
529
|
-
|
530
|
-
|
498
|
+
|
531
499
|
# Queue span with initial state (input phase)
|
532
500
|
if self.background_span_service:
|
533
501
|
self.background_span_service.queue_span(span, span_state="input")
|
534
|
-
|
502
|
+
|
535
503
|
try:
|
536
504
|
yield self
|
537
505
|
finally:
|
538
506
|
duration = time.time() - start_time
|
539
507
|
span.duration = duration
|
540
|
-
|
508
|
+
|
541
509
|
# Queue span with completed state (output phase)
|
542
510
|
if self.background_span_service:
|
543
511
|
self.background_span_service.queue_span(span, span_state="completed")
|
544
|
-
|
512
|
+
|
545
513
|
# Clean up depth tracking for this span_id
|
546
514
|
if span_id in self._span_depths:
|
547
515
|
del self._span_depths[span_id]
|
@@ -561,12 +529,11 @@ class TraceClient:
|
|
561
529
|
expected_tools: Optional[List[str]] = None,
|
562
530
|
additional_metadata: Optional[Dict[str, Any]] = None,
|
563
531
|
model: Optional[str] = None,
|
564
|
-
span_id: Optional[str] = None,
|
565
|
-
log_results: Optional[bool] = True
|
532
|
+
span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
|
566
533
|
):
|
567
534
|
if not self.enable_evaluations:
|
568
535
|
return
|
569
|
-
|
536
|
+
|
570
537
|
start_time = time.time() # Record start time
|
571
538
|
|
572
539
|
try:
|
@@ -574,21 +541,35 @@ class TraceClient:
|
|
574
541
|
if not scorers:
|
575
542
|
warnings.warn("No valid scorers available for evaluation")
|
576
543
|
return
|
577
|
-
|
544
|
+
|
578
545
|
# Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
|
579
|
-
if self.rules and any(
|
580
|
-
|
581
|
-
|
546
|
+
if self.rules and any(
|
547
|
+
isinstance(scorer, JudgevalScorer) for scorer in scorers
|
548
|
+
):
|
549
|
+
raise ValueError(
|
550
|
+
"Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types."
|
551
|
+
)
|
552
|
+
|
582
553
|
except Exception as e:
|
583
554
|
warnings.warn(f"Failed to load scorers: {str(e)}")
|
584
555
|
return
|
585
|
-
|
556
|
+
|
586
557
|
# If example is not provided, create one from the individual parameters
|
587
558
|
if example is None:
|
588
559
|
# Check if any of the individual parameters are provided
|
589
|
-
if any(
|
590
|
-
|
591
|
-
|
560
|
+
if any(
|
561
|
+
param is not None
|
562
|
+
for param in [
|
563
|
+
input,
|
564
|
+
actual_output,
|
565
|
+
expected_output,
|
566
|
+
context,
|
567
|
+
retrieval_context,
|
568
|
+
tools_called,
|
569
|
+
expected_tools,
|
570
|
+
additional_metadata,
|
571
|
+
]
|
572
|
+
):
|
592
573
|
example = Example(
|
593
574
|
input=input,
|
594
575
|
actual_output=actual_output,
|
@@ -600,12 +581,14 @@ class TraceClient:
|
|
600
581
|
additional_metadata=additional_metadata,
|
601
582
|
)
|
602
583
|
else:
|
603
|
-
raise ValueError(
|
604
|
-
|
584
|
+
raise ValueError(
|
585
|
+
"Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
|
586
|
+
)
|
587
|
+
|
605
588
|
# Check examples before creating evaluation run
|
606
|
-
|
589
|
+
|
607
590
|
# check_examples([example], scorers)
|
608
|
-
|
591
|
+
|
609
592
|
# --- Modification: Capture span_id immediately ---
|
610
593
|
# span_id_at_eval_call = current_span_var.get()
|
611
594
|
# print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
|
@@ -617,37 +600,32 @@ class TraceClient:
|
|
617
600
|
# Combine the trace-level rules with any evaluation-specific rules)
|
618
601
|
eval_run = EvaluationRun(
|
619
602
|
organization_id=self.tracer.organization_id,
|
620
|
-
log_results=log_results,
|
621
603
|
project_name=self.project_name,
|
622
604
|
eval_name=f"{self.name.capitalize()}-"
|
623
|
-
|
624
|
-
|
605
|
+
f"{span_id_to_use}-" # Keep original eval name format using context var if available
|
606
|
+
f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
|
625
607
|
examples=[example],
|
626
608
|
scorers=scorers,
|
627
609
|
model=model,
|
628
|
-
metadata={},
|
629
610
|
judgment_api_key=self.tracer.api_key,
|
630
611
|
override=self.overwrite,
|
631
|
-
trace_span_id=span_id_to_use,
|
632
|
-
rules=self.rules # Use the combined rules
|
612
|
+
trace_span_id=span_id_to_use,
|
633
613
|
)
|
634
|
-
|
614
|
+
|
635
615
|
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
636
|
-
|
616
|
+
|
637
617
|
# Queue evaluation run through background service
|
638
618
|
if self.background_span_service and span_id_to_use:
|
639
619
|
# Get the current span data to avoid race conditions
|
640
620
|
current_span = self.span_id_to_span.get(span_id_to_use)
|
641
621
|
if current_span:
|
642
622
|
self.background_span_service.queue_evaluation_run(
|
643
|
-
eval_run,
|
644
|
-
span_id=span_id_to_use,
|
645
|
-
span_data=current_span
|
623
|
+
eval_run, span_id=span_id_to_use, span_data=current_span
|
646
624
|
)
|
647
|
-
|
625
|
+
|
648
626
|
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
649
|
-
# --- Modification: Use span_id from eval_run ---
|
650
|
-
current_span_id = eval_run.trace_span_id
|
627
|
+
# --- Modification: Use span_id from eval_run ---
|
628
|
+
current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
|
651
629
|
# print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
|
652
630
|
# --- End Modification ---
|
653
631
|
|
@@ -658,10 +636,10 @@ class TraceClient:
|
|
658
636
|
self.evaluation_runs.append(eval_run)
|
659
637
|
|
660
638
|
def add_annotation(self, annotation: TraceAnnotation):
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
639
|
+
"""Add an annotation to this trace context"""
|
640
|
+
self.annotations.append(annotation)
|
641
|
+
return self
|
642
|
+
|
665
643
|
def record_input(self, inputs: dict):
|
666
644
|
current_span_id = self.get_current_span()
|
667
645
|
if current_span_id:
|
@@ -670,17 +648,20 @@ class TraceClient:
|
|
670
648
|
if "self" in inputs:
|
671
649
|
del inputs["self"]
|
672
650
|
span.inputs = inputs
|
673
|
-
|
651
|
+
|
674
652
|
# Queue span with input data
|
675
|
-
|
676
|
-
self.background_span_service
|
677
|
-
|
653
|
+
try:
|
654
|
+
if self.background_span_service:
|
655
|
+
self.background_span_service.queue_span(span, span_state="input")
|
656
|
+
except Exception as e:
|
657
|
+
warnings.warn(f"Failed to queue span with input data: {e}")
|
658
|
+
|
678
659
|
def record_agent_name(self, agent_name: str):
|
679
660
|
current_span_id = self.get_current_span()
|
680
661
|
if current_span_id:
|
681
662
|
span = self.span_id_to_span[current_span_id]
|
682
663
|
span.agent_name = agent_name
|
683
|
-
|
664
|
+
|
684
665
|
# Queue span with agent_name data
|
685
666
|
if self.background_span_service:
|
686
667
|
self.background_span_service.queue_span(span, span_state="agent_name")
|
@@ -695,11 +676,11 @@ class TraceClient:
|
|
695
676
|
if current_span_id:
|
696
677
|
span = self.span_id_to_span[current_span_id]
|
697
678
|
span.state_before = state
|
698
|
-
|
679
|
+
|
699
680
|
# Queue span with state_before data
|
700
681
|
if self.background_span_service:
|
701
682
|
self.background_span_service.queue_span(span, span_state="state_before")
|
702
|
-
|
683
|
+
|
703
684
|
def record_state_after(self, state: dict):
|
704
685
|
"""Records the agent's state after a tool execution on the current span.
|
705
686
|
|
@@ -710,7 +691,7 @@ class TraceClient:
|
|
710
691
|
if current_span_id:
|
711
692
|
span = self.span_id_to_span[current_span_id]
|
712
693
|
span.state_after = state
|
713
|
-
|
694
|
+
|
714
695
|
# Queue span with state_after data
|
715
696
|
if self.background_span_service:
|
716
697
|
self.background_span_service.queue_span(span, span_state="state_after")
|
@@ -720,19 +701,19 @@ class TraceClient:
|
|
720
701
|
try:
|
721
702
|
result = await coroutine
|
722
703
|
setattr(span, field, result)
|
723
|
-
|
704
|
+
|
724
705
|
# Queue span with output data now that coroutine is complete
|
725
706
|
if self.background_span_service and field == "output":
|
726
707
|
self.background_span_service.queue_span(span, span_state="output")
|
727
|
-
|
708
|
+
|
728
709
|
return result
|
729
710
|
except Exception as e:
|
730
711
|
setattr(span, field, f"Error: {str(e)}")
|
731
|
-
|
712
|
+
|
732
713
|
# Queue span even if there was an error
|
733
714
|
if self.background_span_service and field == "output":
|
734
715
|
self.background_span_service.queue_span(span, span_state="output")
|
735
|
-
|
716
|
+
|
736
717
|
raise
|
737
718
|
|
738
719
|
def record_output(self, output: Any):
|
@@ -740,56 +721,56 @@ class TraceClient:
|
|
740
721
|
if current_span_id:
|
741
722
|
span = self.span_id_to_span[current_span_id]
|
742
723
|
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
743
|
-
|
724
|
+
|
744
725
|
if inspect.iscoroutine(output):
|
745
726
|
asyncio.create_task(self._update_coroutine(span, output, "output"))
|
746
|
-
|
727
|
+
|
747
728
|
# # Queue span with output data (unless it's pending)
|
748
729
|
if self.background_span_service and not inspect.iscoroutine(output):
|
749
730
|
self.background_span_service.queue_span(span, span_state="output")
|
750
731
|
|
751
|
-
return span
|
732
|
+
return span # Return the created entry
|
752
733
|
# Removed else block - original didn't have one
|
753
|
-
return None
|
754
|
-
|
734
|
+
return None # Return None if no span_id found
|
735
|
+
|
755
736
|
def record_usage(self, usage: TraceUsage):
|
756
737
|
current_span_id = self.get_current_span()
|
757
738
|
if current_span_id:
|
758
739
|
span = self.span_id_to_span[current_span_id]
|
759
740
|
span.usage = usage
|
760
|
-
|
741
|
+
|
761
742
|
# Queue span with usage data
|
762
743
|
if self.background_span_service:
|
763
744
|
self.background_span_service.queue_span(span, span_state="usage")
|
764
|
-
|
765
|
-
return span
|
745
|
+
|
746
|
+
return span # Return the created entry
|
766
747
|
# Removed else block - original didn't have one
|
767
|
-
return None
|
768
|
-
|
748
|
+
return None # Return None if no span_id found
|
749
|
+
|
769
750
|
def record_error(self, error: Dict[str, Any]):
|
770
751
|
current_span_id = self.get_current_span()
|
771
752
|
if current_span_id:
|
772
753
|
span = self.span_id_to_span[current_span_id]
|
773
754
|
span.error = error
|
774
|
-
|
755
|
+
|
775
756
|
# Queue span with error data
|
776
757
|
if self.background_span_service:
|
777
758
|
self.background_span_service.queue_span(span, span_state="error")
|
778
|
-
|
759
|
+
|
779
760
|
return span
|
780
761
|
return None
|
781
|
-
|
762
|
+
|
782
763
|
def add_span(self, span: TraceSpan):
|
783
764
|
"""Add a trace span to this trace context"""
|
784
765
|
self.trace_spans.append(span)
|
785
766
|
self.span_id_to_span[span.span_id] = span
|
786
767
|
return self
|
787
|
-
|
768
|
+
|
788
769
|
def print(self):
|
789
770
|
"""Print the complete trace with proper visual structure"""
|
790
771
|
for span in self.trace_spans:
|
791
772
|
span.print_span()
|
792
|
-
|
773
|
+
|
793
774
|
def get_duration(self) -> float:
|
794
775
|
"""
|
795
776
|
Get the total duration of this trace
|
@@ -798,57 +779,23 @@ class TraceClient:
|
|
798
779
|
return 0.0 # No duration if trace hasn't been saved yet
|
799
780
|
return time.time() - self.start_time
|
800
781
|
|
801
|
-
def save(
|
802
|
-
|
803
|
-
|
804
|
-
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
805
|
-
"""
|
806
|
-
# Set start_time if this is the first save
|
807
|
-
if self.start_time is None:
|
808
|
-
self.start_time = time.time()
|
809
|
-
|
810
|
-
# Calculate total elapsed time
|
811
|
-
total_duration = self.get_duration()
|
812
|
-
# Create trace document - Always use standard keys for top-level counts
|
813
|
-
trace_data = {
|
814
|
-
"trace_id": self.trace_id,
|
815
|
-
"name": self.name,
|
816
|
-
"project_name": self.project_name,
|
817
|
-
"created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
|
818
|
-
"duration": total_duration,
|
819
|
-
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
820
|
-
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
821
|
-
"overwrite": overwrite,
|
822
|
-
"offline_mode": self.tracer.offline_mode,
|
823
|
-
"parent_trace_id": self.parent_trace_id,
|
824
|
-
"parent_name": self.parent_name
|
825
|
-
}
|
826
|
-
# --- Log trace data before saving ---
|
827
|
-
server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
|
828
|
-
|
829
|
-
# upload annotations
|
830
|
-
# TODO: batch to the log endpoint
|
831
|
-
for annotation in self.annotations:
|
832
|
-
self.trace_manager_client.save_annotation(annotation)
|
833
|
-
|
834
|
-
return self.trace_id, server_response
|
835
|
-
|
836
|
-
def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
|
782
|
+
def save(
|
783
|
+
self, overwrite: bool = False, final_save: bool = False
|
784
|
+
) -> Tuple[str, dict]:
|
837
785
|
"""
|
838
786
|
Save the current trace to the database with rate limiting checks.
|
839
787
|
First checks usage limits, then upserts the trace if allowed.
|
840
|
-
|
788
|
+
|
841
789
|
Args:
|
842
790
|
overwrite: Whether to overwrite existing traces
|
843
791
|
final_save: Whether this is the final save (updates usage counters)
|
844
|
-
|
792
|
+
|
845
793
|
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
846
794
|
"""
|
847
795
|
|
848
|
-
|
849
796
|
# Calculate total elapsed time
|
850
797
|
total_duration = self.get_duration()
|
851
|
-
|
798
|
+
|
852
799
|
# Create trace document
|
853
800
|
trace_data = {
|
854
801
|
"trace_id": self.trace_id,
|
@@ -861,35 +808,20 @@ class TraceClient:
|
|
861
808
|
"overwrite": overwrite,
|
862
809
|
"offline_mode": self.tracer.offline_mode,
|
863
810
|
"parent_trace_id": self.parent_trace_id,
|
864
|
-
"parent_name": self.parent_name
|
811
|
+
"parent_name": self.parent_name,
|
812
|
+
"customer_id": self.customer_id,
|
813
|
+
"tags": self.tags,
|
814
|
+
"metadata": self.metadata,
|
865
815
|
}
|
866
|
-
|
867
|
-
# Check usage limits first
|
868
|
-
try:
|
869
|
-
usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
|
870
|
-
# Usage check passed silently - no need to show detailed info
|
871
|
-
except ValueError as e:
|
872
|
-
# Rate limit exceeded
|
873
|
-
warnings.warn(f"Rate limit check failed for live tracing: {e}")
|
874
|
-
raise e
|
875
|
-
|
816
|
+
|
876
817
|
# If usage check passes, upsert the trace
|
877
818
|
server_response = self.trace_manager_client.upsert_trace(
|
878
|
-
trace_data,
|
819
|
+
trace_data,
|
879
820
|
offline_mode=self.tracer.offline_mode,
|
880
821
|
show_link=not final_save, # Show link only on initial save, not final save
|
881
|
-
final_save=final_save # Pass final_save to control S3 saving
|
822
|
+
final_save=final_save, # Pass final_save to control S3 saving
|
882
823
|
)
|
883
824
|
|
884
|
-
# Update usage counters only on final save
|
885
|
-
if final_save:
|
886
|
-
try:
|
887
|
-
usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
|
888
|
-
# Usage updated silently - no need to show detailed usage info
|
889
|
-
except ValueError as e:
|
890
|
-
# Log warning but don't fail the trace save since the trace was already saved
|
891
|
-
warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
|
892
|
-
|
893
825
|
# Upload annotations
|
894
826
|
# TODO: batch to the log endpoint
|
895
827
|
for annotation in self.annotations:
|
@@ -900,8 +832,77 @@ class TraceClient:
|
|
900
832
|
|
901
833
|
def delete(self):
|
902
834
|
return self.trace_manager_client.delete_trace(self.trace_id)
|
903
|
-
|
904
|
-
def
|
835
|
+
|
836
|
+
def update_metadata(self, metadata: dict):
|
837
|
+
"""
|
838
|
+
Set metadata for this trace.
|
839
|
+
|
840
|
+
Args:
|
841
|
+
metadata: Metadata as a dictionary
|
842
|
+
|
843
|
+
Supported keys:
|
844
|
+
- customer_id: ID of the customer using this trace
|
845
|
+
- tags: List of tags for this trace
|
846
|
+
- has_notification: Whether this trace has a notification
|
847
|
+
- overwrite: Whether to overwrite existing traces
|
848
|
+
- name: Name of the trace
|
849
|
+
"""
|
850
|
+
for k, v in metadata.items():
|
851
|
+
if k == "customer_id":
|
852
|
+
if v is not None:
|
853
|
+
self.customer_id = str(v)
|
854
|
+
else:
|
855
|
+
self.customer_id = None
|
856
|
+
elif k == "tags":
|
857
|
+
if isinstance(v, list):
|
858
|
+
# Validate that all items in the list are of the expected types
|
859
|
+
for item in v:
|
860
|
+
if not isinstance(item, (str, set, tuple)):
|
861
|
+
raise ValueError(
|
862
|
+
f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
|
863
|
+
)
|
864
|
+
self.tags = v
|
865
|
+
else:
|
866
|
+
raise ValueError(
|
867
|
+
f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
|
868
|
+
)
|
869
|
+
elif k == "has_notification":
|
870
|
+
if not isinstance(v, bool):
|
871
|
+
raise ValueError(
|
872
|
+
f"has_notification must be a boolean, got {type(v)}"
|
873
|
+
)
|
874
|
+
self.has_notification = v
|
875
|
+
elif k == "overwrite":
|
876
|
+
if not isinstance(v, bool):
|
877
|
+
raise ValueError(f"overwrite must be a boolean, got {type(v)}")
|
878
|
+
self.overwrite = v
|
879
|
+
elif k == "name":
|
880
|
+
self.name = v
|
881
|
+
else:
|
882
|
+
self.metadata[k] = v
|
883
|
+
|
884
|
+
def set_customer_id(self, customer_id: str):
|
885
|
+
"""
|
886
|
+
Set the customer ID for this trace.
|
887
|
+
|
888
|
+
Args:
|
889
|
+
customer_id: The customer ID to set
|
890
|
+
"""
|
891
|
+
self.update_metadata({"customer_id": customer_id})
|
892
|
+
|
893
|
+
def set_tags(self, tags: List[Union[str, set, tuple]]):
|
894
|
+
"""
|
895
|
+
Set the tags for this trace.
|
896
|
+
|
897
|
+
Args:
|
898
|
+
tags: List of tags to set
|
899
|
+
"""
|
900
|
+
self.update_metadata({"tags": tags})
|
901
|
+
|
902
|
+
|
903
|
+
def _capture_exception_for_trace(
|
904
|
+
current_trace: Optional["TraceClient"], exc_info: ExcInfo
|
905
|
+
):
|
905
906
|
if not current_trace:
|
906
907
|
return
|
907
908
|
|
@@ -909,18 +910,20 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
|
|
909
910
|
formatted_exception = {
|
910
911
|
"type": exc_type.__name__ if exc_type else "UnknownExceptionType",
|
911
912
|
"message": str(exc_value) if exc_value else "No exception message",
|
912
|
-
"traceback": traceback.format_tb(exc_traceback_obj)
|
913
|
+
"traceback": traceback.format_tb(exc_traceback_obj)
|
914
|
+
if exc_traceback_obj
|
915
|
+
else [],
|
913
916
|
}
|
914
|
-
|
917
|
+
|
915
918
|
# This is where we specially handle exceptions that we might want to collect additional data for.
|
916
919
|
# When we do this, always try checking the module from sys.modules instead of importing. This will
|
917
920
|
# Let us support a wider range of exceptions without needing to import them for all clients.
|
918
|
-
|
919
|
-
# Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
|
921
|
+
|
922
|
+
# Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
|
920
923
|
# The alternative is to hand select libraries we want from sys.modules and check for them:
|
921
924
|
# As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
|
922
925
|
|
923
|
-
|
926
|
+
# General HTTP Like errors
|
924
927
|
try:
|
925
928
|
url = getattr(getattr(exc_value, "request", None), "url", None)
|
926
929
|
status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
|
@@ -929,34 +932,43 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
|
|
929
932
|
"url": url if url else "Unknown URL",
|
930
933
|
"status_code": status_code if status_code else None,
|
931
934
|
}
|
932
|
-
except Exception
|
935
|
+
except Exception:
|
933
936
|
pass
|
934
937
|
|
935
938
|
current_trace.record_error(formatted_exception)
|
936
|
-
|
939
|
+
|
937
940
|
# Queue the span with error state through background service
|
938
941
|
if current_trace.background_span_service:
|
939
942
|
current_span_id = current_trace.get_current_span()
|
940
943
|
if current_span_id and current_span_id in current_trace.span_id_to_span:
|
941
944
|
error_span = current_trace.span_id_to_span[current_span_id]
|
942
|
-
current_trace.background_span_service.queue_span(
|
945
|
+
current_trace.background_span_service.queue_span(
|
946
|
+
error_span, span_state="error"
|
947
|
+
)
|
948
|
+
|
943
949
|
|
944
950
|
class BackgroundSpanService:
|
945
951
|
"""
|
946
952
|
Background service for queueing and batching trace spans for efficient saving.
|
947
|
-
|
953
|
+
|
948
954
|
This service:
|
949
955
|
- Queues spans as they complete
|
950
956
|
- Batches them for efficient network usage
|
951
957
|
- Sends spans periodically or when batches reach a certain size
|
952
958
|
- Handles automatic flushing when the main event terminates
|
953
959
|
"""
|
954
|
-
|
955
|
-
def __init__(
|
956
|
-
|
960
|
+
|
961
|
+
def __init__(
|
962
|
+
self,
|
963
|
+
judgment_api_key: str,
|
964
|
+
organization_id: str,
|
965
|
+
batch_size: int = 10,
|
966
|
+
flush_interval: float = 5.0,
|
967
|
+
num_workers: int = 1,
|
968
|
+
):
|
957
969
|
"""
|
958
970
|
Initialize the background span service.
|
959
|
-
|
971
|
+
|
960
972
|
Args:
|
961
973
|
judgment_api_key: API key for Judgment service
|
962
974
|
organization_id: Organization ID
|
@@ -969,41 +981,41 @@ class BackgroundSpanService:
|
|
969
981
|
self.batch_size = batch_size
|
970
982
|
self.flush_interval = flush_interval
|
971
983
|
self.num_workers = max(1, num_workers) # Ensure at least 1 worker
|
972
|
-
|
984
|
+
|
973
985
|
# Queue for pending spans
|
974
|
-
self._span_queue = queue.Queue()
|
975
|
-
|
986
|
+
self._span_queue: queue.Queue[Dict[str, Any]] = queue.Queue()
|
987
|
+
|
976
988
|
# Background threads for processing spans
|
977
|
-
self._worker_threads = []
|
989
|
+
self._worker_threads: List[threading.Thread] = []
|
978
990
|
self._shutdown_event = threading.Event()
|
979
|
-
|
991
|
+
|
980
992
|
# Track spans that have been sent
|
981
|
-
self._sent_spans = set()
|
982
|
-
|
993
|
+
# self._sent_spans = set()
|
994
|
+
|
983
995
|
# Register cleanup on exit
|
984
996
|
atexit.register(self.shutdown)
|
985
|
-
|
997
|
+
|
986
998
|
# Start the background workers
|
987
999
|
self._start_workers()
|
988
|
-
|
1000
|
+
|
989
1001
|
def _start_workers(self):
|
990
1002
|
"""Start the background worker threads."""
|
991
1003
|
for i in range(self.num_workers):
|
992
1004
|
if len(self._worker_threads) < self.num_workers:
|
993
1005
|
worker_thread = threading.Thread(
|
994
|
-
target=self._worker_loop,
|
995
|
-
daemon=True,
|
996
|
-
name=f"SpanWorker-{i+1}"
|
1006
|
+
target=self._worker_loop, daemon=True, name=f"SpanWorker-{i + 1}"
|
997
1007
|
)
|
998
1008
|
worker_thread.start()
|
999
1009
|
self._worker_threads.append(worker_thread)
|
1000
|
-
|
1010
|
+
|
1001
1011
|
def _worker_loop(self):
|
1002
1012
|
"""Main worker loop that processes spans in batches."""
|
1003
1013
|
batch = []
|
1004
1014
|
last_flush_time = time.time()
|
1005
|
-
pending_task_count =
|
1006
|
-
|
1015
|
+
pending_task_count = (
|
1016
|
+
0 # Track how many tasks we've taken from queue but not marked done
|
1017
|
+
)
|
1018
|
+
|
1007
1019
|
while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
|
1008
1020
|
try:
|
1009
1021
|
# First, do a blocking get to wait for at least one item
|
@@ -1015,7 +1027,7 @@ class BackgroundSpanService:
|
|
1015
1027
|
except queue.Empty:
|
1016
1028
|
# No new spans, continue to check for flush conditions
|
1017
1029
|
pass
|
1018
|
-
|
1030
|
+
|
1019
1031
|
# Then, do non-blocking gets to drain any additional available items
|
1020
1032
|
# up to our batch size limit
|
1021
1033
|
while len(batch) < self.batch_size:
|
@@ -1026,24 +1038,23 @@ class BackgroundSpanService:
|
|
1026
1038
|
except queue.Empty:
|
1027
1039
|
# No more items immediately available
|
1028
1040
|
break
|
1029
|
-
|
1041
|
+
|
1030
1042
|
current_time = time.time()
|
1031
|
-
should_flush = (
|
1032
|
-
|
1033
|
-
(batch and (current_time - last_flush_time) >= self.flush_interval)
|
1043
|
+
should_flush = len(batch) >= self.batch_size or (
|
1044
|
+
batch and (current_time - last_flush_time) >= self.flush_interval
|
1034
1045
|
)
|
1035
|
-
|
1046
|
+
|
1036
1047
|
if should_flush and batch:
|
1037
1048
|
self._send_batch(batch)
|
1038
|
-
|
1049
|
+
|
1039
1050
|
# Only mark tasks as done after successful sending
|
1040
1051
|
for _ in range(pending_task_count):
|
1041
1052
|
self._span_queue.task_done()
|
1042
1053
|
pending_task_count = 0 # Reset counter
|
1043
|
-
|
1054
|
+
|
1044
1055
|
batch.clear()
|
1045
1056
|
last_flush_time = current_time
|
1046
|
-
|
1057
|
+
|
1047
1058
|
except Exception as e:
|
1048
1059
|
warnings.warn(f"Error in span service worker loop: {e}")
|
1049
1060
|
# On error, still need to mark tasks as done to prevent hanging
|
@@ -1051,53 +1062,50 @@ class BackgroundSpanService:
|
|
1051
1062
|
self._span_queue.task_done()
|
1052
1063
|
pending_task_count = 0
|
1053
1064
|
batch.clear()
|
1054
|
-
|
1065
|
+
|
1055
1066
|
# Final flush on shutdown
|
1056
1067
|
if batch:
|
1057
1068
|
self._send_batch(batch)
|
1058
1069
|
# Mark remaining tasks as done
|
1059
1070
|
for _ in range(pending_task_count):
|
1060
1071
|
self._span_queue.task_done()
|
1061
|
-
|
1072
|
+
|
1062
1073
|
def _send_batch(self, batch: List[Dict[str, Any]]):
|
1063
1074
|
"""
|
1064
1075
|
Send a batch of spans to the server.
|
1065
|
-
|
1076
|
+
|
1066
1077
|
Args:
|
1067
1078
|
batch: List of span dictionaries to send
|
1068
1079
|
"""
|
1069
1080
|
if not batch:
|
1070
1081
|
return
|
1071
|
-
|
1082
|
+
|
1072
1083
|
try:
|
1073
|
-
# Group
|
1084
|
+
# Group items by type for different endpoints
|
1074
1085
|
spans_to_send = []
|
1075
1086
|
evaluation_runs_to_send = []
|
1076
|
-
|
1087
|
+
|
1077
1088
|
for item in batch:
|
1078
|
-
if item[
|
1079
|
-
spans_to_send.append(item[
|
1080
|
-
elif item[
|
1081
|
-
evaluation_runs_to_send.append(item[
|
1082
|
-
|
1089
|
+
if item["type"] == "span":
|
1090
|
+
spans_to_send.append(item["data"])
|
1091
|
+
elif item["type"] == "evaluation_run":
|
1092
|
+
evaluation_runs_to_send.append(item["data"])
|
1093
|
+
|
1083
1094
|
# Send spans if any
|
1084
1095
|
if spans_to_send:
|
1085
1096
|
self._send_spans_batch(spans_to_send)
|
1086
|
-
|
1097
|
+
|
1087
1098
|
# Send evaluation runs if any
|
1088
1099
|
if evaluation_runs_to_send:
|
1089
1100
|
self._send_evaluation_runs_batch(evaluation_runs_to_send)
|
1090
|
-
|
1101
|
+
|
1091
1102
|
except Exception as e:
|
1092
|
-
warnings.warn(f"Failed to send
|
1093
|
-
|
1103
|
+
warnings.warn(f"Failed to send batch: {e}")
|
1104
|
+
|
1094
1105
|
def _send_spans_batch(self, spans: List[Dict[str, Any]]):
|
1095
1106
|
"""Send a batch of spans to the spans endpoint."""
|
1096
|
-
payload = {
|
1097
|
-
|
1098
|
-
"organization_id": self.organization_id
|
1099
|
-
}
|
1100
|
-
|
1107
|
+
payload = {"spans": spans, "organization_id": self.organization_id}
|
1108
|
+
|
1101
1109
|
# Serialize with fallback encoder
|
1102
1110
|
def fallback_encoder(obj):
|
1103
1111
|
try:
|
@@ -1107,10 +1115,10 @@ class BackgroundSpanService:
|
|
1107
1115
|
return str(obj)
|
1108
1116
|
except Exception as e:
|
1109
1117
|
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
1110
|
-
|
1118
|
+
|
1111
1119
|
try:
|
1112
1120
|
serialized_data = json.dumps(payload, default=fallback_encoder)
|
1113
|
-
|
1121
|
+
|
1114
1122
|
# Send the actual HTTP request to the batch endpoint
|
1115
1123
|
response = requests.post(
|
1116
1124
|
JUDGMENT_TRACES_SPANS_BATCH_API_URL,
|
@@ -1118,21 +1126,22 @@ class BackgroundSpanService:
|
|
1118
1126
|
headers={
|
1119
1127
|
"Content-Type": "application/json",
|
1120
1128
|
"Authorization": f"Bearer {self.judgment_api_key}",
|
1121
|
-
"X-Organization-Id": self.organization_id
|
1129
|
+
"X-Organization-Id": self.organization_id,
|
1122
1130
|
},
|
1123
1131
|
verify=True,
|
1124
|
-
timeout=30 # Add timeout to prevent hanging
|
1132
|
+
timeout=30, # Add timeout to prevent hanging
|
1125
1133
|
)
|
1126
|
-
|
1134
|
+
|
1127
1135
|
if response.status_code != HTTPStatus.OK:
|
1128
|
-
warnings.warn(
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1136
|
+
warnings.warn(
|
1137
|
+
f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
|
1138
|
+
)
|
1139
|
+
|
1140
|
+
except RequestException as e:
|
1132
1141
|
warnings.warn(f"Network error sending spans batch: {e}")
|
1133
1142
|
except Exception as e:
|
1134
1143
|
warnings.warn(f"Failed to serialize or send spans batch: {e}")
|
1135
|
-
|
1144
|
+
|
1136
1145
|
def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
|
1137
1146
|
"""Send a batch of evaluation runs with their associated span data to the endpoint."""
|
1138
1147
|
# Structure payload to include both evaluation run data and span data
|
@@ -1142,22 +1151,23 @@ class BackgroundSpanService:
|
|
1142
1151
|
entry = {
|
1143
1152
|
"evaluation_run": {
|
1144
1153
|
# Extract evaluation run fields (excluding span-specific fields)
|
1145
|
-
key: value
|
1146
|
-
|
1154
|
+
key: value
|
1155
|
+
for key, value in eval_data.items()
|
1156
|
+
if key not in ["associated_span_id", "span_data", "queued_at"]
|
1147
1157
|
},
|
1148
1158
|
"associated_span": {
|
1149
|
-
"span_id": eval_data.get(
|
1150
|
-
"span_data": eval_data.get(
|
1159
|
+
"span_id": eval_data.get("associated_span_id"),
|
1160
|
+
"span_data": eval_data.get("span_data"),
|
1151
1161
|
},
|
1152
|
-
"queued_at": eval_data.get(
|
1162
|
+
"queued_at": eval_data.get("queued_at"),
|
1153
1163
|
}
|
1154
1164
|
evaluation_entries.append(entry)
|
1155
|
-
|
1165
|
+
|
1156
1166
|
payload = {
|
1157
1167
|
"organization_id": self.organization_id,
|
1158
|
-
"evaluation_entries": evaluation_entries # Each entry contains both eval run + span data
|
1168
|
+
"evaluation_entries": evaluation_entries, # Each entry contains both eval run + span data
|
1159
1169
|
}
|
1160
|
-
|
1170
|
+
|
1161
1171
|
# Serialize with fallback encoder
|
1162
1172
|
def fallback_encoder(obj):
|
1163
1173
|
try:
|
@@ -1167,10 +1177,10 @@ class BackgroundSpanService:
|
|
1167
1177
|
return str(obj)
|
1168
1178
|
except Exception as e:
|
1169
1179
|
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
1170
|
-
|
1180
|
+
|
1171
1181
|
try:
|
1172
1182
|
serialized_data = json.dumps(payload, default=fallback_encoder)
|
1173
|
-
|
1183
|
+
|
1174
1184
|
# Send the actual HTTP request to the batch endpoint
|
1175
1185
|
response = requests.post(
|
1176
1186
|
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
|
@@ -1178,25 +1188,26 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30  # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(
-
-
-
+                warnings.warn(
+                    f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending evaluation runs batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to send evaluation runs batch: {e}")
-
+
     def queue_span(self, span: TraceSpan, span_state: str = "input"):
         """
         Queue a span for background sending.
-
+
         Args:
             span: The TraceSpan object to queue
            span_state: State of the span ("input", "output", "completed")
@@ -1207,15 +1218,17 @@ class BackgroundSpanService:
             "data": {
                 **span.model_dump(),
                 "span_state": span_state,
-                "queued_at": time.time()
-            }
+                "queued_at": time.time(),
+            },
         }
         self._span_queue.put(span_data)
-
-    def queue_evaluation_run(
+
+    def queue_evaluation_run(
+        self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan
+    ):
         """
         Queue an evaluation run for background sending.
-
+
         Args:
             evaluation_run: The EvaluationRun object to queue
             span_id: The span ID associated with this evaluation run
@@ -1228,11 +1241,11 @@ class BackgroundSpanService:
                 **evaluation_run.model_dump(),
                 "associated_span_id": span_id,
                 "span_data": span_data.model_dump(),  # Include span data to avoid race conditions
-                "queued_at": time.time()
-            }
+                "queued_at": time.time(),
+            },
         }
         self._span_queue.put(eval_data)
-
+
     def flush(self):
         """Force immediate sending of all queued spans."""
         try:
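Taken together, these hunks keep `queue_span` and `queue_evaluation_run` as thin producers over a single queue that background workers drain into the batch endpoints. A minimal usage sketch based on the signatures visible above; the `tracer`, `span`, and `run` objects are placeholders, not values from this diff:

```python
# Hypothetical producer-side sketch for BackgroundSpanService.
service = tracer.get_background_span_service()  # Optional[BackgroundSpanService]
if service:
    service.queue_span(span, span_state="completed")
    service.queue_evaluation_run(run, span_id=span.span_id, span_data=span)
    print(service.get_queue_size())  # items still waiting for a worker thread
    service.flush()                  # block until the queue has been drained
```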
@@ -1240,89 +1253,101 @@ class BackgroundSpanService:
             self._span_queue.join()
         except Exception as e:
             warnings.warn(f"Error during flush: {e}")
-
+
     def shutdown(self):
         """Shutdown the background service and flush remaining spans."""
         if self._shutdown_event.is_set():
             return
-
+
         try:
             # Signal shutdown to stop new items from being queued
             self._shutdown_event.set()
-
+
             # Try to flush any remaining spans
             try:
                 self.flush()
             except Exception as e:
-                warnings.warn(f"Error during final flush: {e}")
+                warnings.warn(f"Error during final flush: {e}")
         except Exception as e:
             warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
         finally:
             # Clear the worker threads list (daemon threads will be killed automatically)
             self._worker_threads.clear()
-
+
     def get_queue_size(self) -> int:
         """Get the current size of the span queue."""
         return self._span_queue.qsize()
 
+
 class _DeepTracer:
     _instance: Optional["_DeepTracer"] = None
     _lock: threading.Lock = threading.Lock()
     _refcount: int = 0
-    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
-
+    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
+        "_deep_profiler_span_stack", default=[]
+    )
+    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
+        "_deep_profiler_skip_stack", default=[]
+    )
     _original_sys_trace: Optional[Callable] = None
     _original_threading_trace: Optional[Callable] = None
 
-    def __init__(self, tracer:
+    def __init__(self, tracer: "Tracer"):
         self._tracer = tracer
 
     def _get_qual_name(self, frame) -> str:
         func_name = frame.f_code.co_name
         module_name = frame.f_globals.get("__name__", "unknown_module")
-
+
         try:
             func = frame.f_globals.get(func_name)
             if func is None:
                 return f"{module_name}.{func_name}"
             if hasattr(func, "__qualname__"):
-
+                return f"{module_name}.{func.__qualname__}"
+            return f"{module_name}.{func_name}"
         except Exception:
             return f"{module_name}.{func_name}"
-
-    def __new__(cls, tracer:
+
+    def __new__(cls, tracer: "Tracer"):
         with cls._lock:
             if cls._instance is None:
                 cls._instance = super().__new__(cls)
             return cls._instance
-
+
     def _should_trace(self, frame):
         # Skip stack is maintained by the tracer as an optimization to skip earlier
         # frames in the call stack that we've already determined should be skipped
         skip_stack = self._skip_stack.get()
         if len(skip_stack) > 0:
             return False
-
+
         func_name = frame.f_code.co_name
         module_name = frame.f_globals.get("__name__", None)
-
         func = frame.f_globals.get(func_name)
-        if func and (
+        if func and (
+            hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
+        ):
             return False
 
         if (
             not module_name
-            or func_name.startswith("<")
-            or func_name.startswith("__")
+            or func_name.startswith("<")  # ex: <listcomp>
+            or func_name.startswith("__")
+            and func_name != "__call__"  # dunders
             or not self._is_user_code(frame.f_code.co_filename)
         ):
             return False
-
+
         return True
-
+
     @functools.cache
     def _is_user_code(self, filename: str):
-        return
+        return (
+            bool(filename)
+            and not filename.startswith("<")
+            and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+        )
 
     def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
         """Cooperative trace function for sys.settrace that chains with existing tracers."""
@@ -1334,18 +1359,20 @@ class _DeepTracer:
         except Exception:
             # If the original tracer fails, continue with our tracing
             pass
-
+
         # Then do our own tracing
         our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
-
+
         # Return our tracer to continue tracing, but respect the original's decision
         # If the original tracer returned None (stop tracing), we should respect that
         if original_result is None and self._original_sys_trace:
             return None
-
+
         return our_result or original_result
-
-    def _cooperative_threading_trace(
+
+    def _cooperative_threading_trace(
+        self, frame: types.FrameType, event: str, arg: Any
+    ):
         """Cooperative trace function for threading.settrace that chains with existing tracers."""
         # First, call the original threading trace function if it exists
         original_result = None
@@ -1355,44 +1382,48 @@ class _DeepTracer:
         except Exception:
             # If the original tracer fails, continue with our tracing
             pass
-
+
         # Then do our own tracing
         our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
-
+
         # Return our tracer to continue tracing, but respect the original's decision
         # If the original tracer returned None (stop tracing), we should respect that
         if original_result is None and self._original_threading_trace:
             return None
-
+
         return our_result or original_result
-
-    def _trace(
+
+    def _trace(
+        self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
+    ):
         frame.f_trace_lines = False
         frame.f_trace_opcodes = False
 
         if not self._should_trace(frame):
             return
-
+
         if event not in ("call", "return", "exception"):
             return
-
+
         current_trace = self._tracer.get_current_trace()
         if not current_trace:
             return
-
+
         parent_span_id = self._tracer.get_current_span()
         if not parent_span_id:
             return
 
         qual_name = self._get_qual_name(frame)
         instance_name = None
-        if
-            instance = frame.f_locals[
+        if "self" in frame.f_locals:
+            instance = frame.f_locals["self"]
             class_name = instance.__class__.__name__
-            class_identifiers = getattr(
-            instance_name = get_instance_prefixed_name(
+            class_identifiers = getattr(self._tracer, "class_identifiers", {})
+            instance_name = get_instance_prefixed_name(
+                instance, class_name, class_identifiers
+            )
         skip_stack = self._skip_stack.get()
-
+
         if event == "call":
             # If we have entries in the skip stack and the current qual_name matches the top entry,
             # push it again to track nesting depth and skip
@@ -1402,9 +1433,9 @@ class _DeepTracer:
                 skip_stack.append(qual_name)
                 self._skip_stack.set(skip_stack)
                 return
-
+
             should_trace = self._should_trace(frame)
-
+
             if not should_trace:
                 if not skip_stack:
                     self._skip_stack.set([qual_name])
@@ -1416,35 +1447,37 @@ class _DeepTracer:
                 skip_stack.pop()
                 self._skip_stack.set(skip_stack)
                 return
-
+
         if skip_stack:
             return
-
+
         span_stack = self._span_stack.get()
         if event == "call":
             if not self._should_trace(frame):
                 return
-
+
             span_id = str(uuid.uuid4())
-
+
             parent_depth = current_trace._span_depths.get(parent_span_id, 0)
             depth = parent_depth + 1
-
+
             current_trace._span_depths[span_id] = depth
-
+
             start_time = time.time()
-
-            span_stack.append(
-
-
-
-
-
+
+            span_stack.append(
+                {
+                    "span_id": span_id,
+                    "parent_span_id": parent_span_id,
+                    "function": qual_name,
+                    "start_time": start_time,
+                }
+            )
             self._span_stack.set(span_stack)
-
+
             token = self._tracer.set_current_span(span_id)
             frame.f_locals["_judgment_span_token"] = token
-
+
             span = TraceSpan(
                 span_id=span_id,
                 trace_id=current_trace.trace_id,
@@ -1454,57 +1487,55 @@ class _DeepTracer:
                 span_type="span",
                 parent_span_id=parent_span_id,
                 function=qual_name,
-                agent_name=instance_name
+                agent_name=instance_name,
             )
             current_trace.add_span(span)
-
+
             inputs = {}
             try:
                 args_info = inspect.getargvalues(frame)
                 for arg in args_info.args:
                     try:
                         inputs[arg] = args_info.locals.get(arg)
-                    except:
+                    except Exception:
                         inputs[arg] = "<<Unserializable>>"
                 current_trace.record_input(inputs)
             except Exception as e:
-                current_trace.record_input({
-
-                })
-
+                current_trace.record_input({"error": str(e)})
+
         elif event == "return":
             if not span_stack:
                 return
-
+
             current_id = self._tracer.get_current_span()
-
+
             span_data = None
             for i, entry in enumerate(reversed(span_stack)):
                 if entry["span_id"] == current_id:
-                    span_data = span_stack.pop(-(i+1))
+                    span_data = span_stack.pop(-(i + 1))
                     self._span_stack.set(span_stack)
                     break
-
+
             if not span_data:
                 return
-
+
             start_time = span_data["start_time"]
             duration = time.time() - start_time
-
+
             current_trace.span_id_to_span[span_data["span_id"]].duration = duration
 
             if arg is not None:
-                # exception handling will take priority.
+                # exception handling will take priority.
                 current_trace.record_output(arg)
-
+
             if span_data["span_id"] in current_trace._span_depths:
                 del current_trace._span_depths[span_data["span_id"]]
-
+
             if span_stack:
                 self._tracer.set_current_span(span_stack[-1]["span_id"])
             else:
                 self._tracer.set_current_span(span_data["parent_span_id"])
-
+
             if "_judgment_span_token" in frame.f_locals:
                 self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
 
@@ -1513,10 +1544,9 @@ class _DeepTracer:
             if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
                 return
             _capture_exception_for_trace(current_trace, arg)
-
-
+
         return continuation_func
-
+
     def __enter__(self):
         with self._lock:
             self._refcount += 1
@@ -1524,14 +1554,14 @@ class _DeepTracer:
                 # Store the existing trace functions before setting ours
                 self._original_sys_trace = sys.gettrace()
                 self._original_threading_trace = threading.gettrace()
-
+
                 self._skip_stack.set([])
                 self._span_stack.set([])
-
+
                 sys.settrace(self._cooperative_sys_trace)
                 threading.settrace(self._cooperative_threading_trace)
         return self
-
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         with self._lock:
             self._refcount -= 1
@@ -1539,7 +1569,7 @@ class _DeepTracer:
                 # Restore the original trace functions instead of setting to None
                 sys.settrace(self._original_sys_trace)
                 threading.settrace(self._original_threading_trace)
-
+
                 # Clean up the references
                 self._original_sys_trace = None
                 self._original_threading_trace = None
@@ -1547,7 +1577,7 @@ class _DeepTracer:
 
     # Below commented out function isn't used anymore?
 
-    # def log(self, message: str, level: str = "info"):
+    # def log(self, message: str, level: str = "info"):
     #     """ Log a message with the span context """
     #     current_trace = self._tracer.get_current_trace()
     #     if current_trace:
@@ -1555,25 +1585,29 @@ class _DeepTracer:
     #     else:
     #         print(f"[{level}] {message}")
     #     current_trace.record_output({"log": message})
-
-class Tracer:
 
+
+class Tracer:
     # Tracer.current_trace class variable is currently used in wrap()
     # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
     # Should be fine to do so as long as we keep Tracer as a singleton
     current_trace: Optional[TraceClient] = None
     # current_span_id: Optional[str] = None
 
-    trace_across_async_contexts: bool =
+    trace_across_async_contexts: bool = (
+        False  # BY default, we don't trace across async contexts
+    )
 
     def __init__(
-        self,
-        api_key: str = os.getenv("JUDGMENT_API_KEY"),
-        project_name: str = None,
+        self,
+        api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
+        project_name: str | None = None,
         rules: Optional[List[Rule]] = None,  # Added rules parameter
-        organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
-        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
-
+        organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
+        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
+        == "true",
+        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
+        == "true",
         # S3 configuration
         use_s3: bool = False,
         s3_bucket_name: Optional[str] = None,
@@ -1581,26 +1615,32 @@ class Tracer:
         s3_aws_secret_access_key: Optional[str] = None,
         s3_region_name: Optional[str] = None,
         offline_mode: bool = False,
-        deep_tracing: bool =
-        trace_across_async_contexts: bool = False,
+        deep_tracing: bool = False,  # Deep tracing is disabled by default
+        trace_across_async_contexts: bool = False,  # BY default, we don't trace across async contexts
         # Background span service configuration
         enable_background_spans: bool = True,  # Enable background span service by default
         span_batch_size: int = 50,  # Number of spans to batch before sending
         span_flush_interval: float = 1.0,  # Time in seconds between automatic flushes
-        span_num_workers: int = 10  # Number of worker threads for span processing
-
+        span_num_workers: int = 10,  # Number of worker threads for span processing
+    ):
         if not api_key:
             raise ValueError("Tracer must be configured with a Judgment API key")
-
-
+
+        try:
+            result, response = validate_api_key(api_key)
+        except Exception as e:
+            print(f"Issue with verifying API key, disabling monitoring: {e}")
+            enable_monitoring = False
+            result = True
+
         if not result:
             raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
-
+
         if not organization_id:
             raise ValueError("Tracer must be configured with an Organization ID")
         if use_s3 and not s3_bucket_name:
             raise ValueError("S3 bucket name must be provided when use_s3 is True")
-
+
         self.api_key: str = api_key
         self.project_name: str = project_name or str(uuid.uuid4())
         self.organization_id: str = organization_id
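The new constructor keywords above are easiest to read in one place. A minimal construction sketch based on the signature in this hunk; the project name is a placeholder and the credentials fall back to environment variables exactly as the defaults do:

```python
import os
from judgeval.common.tracer import Tracer

tracer = Tracer(
    api_key=os.getenv("JUDGMENT_API_KEY"),      # raises ValueError if missing
    organization_id=os.getenv("JUDGMENT_ORG_ID"),
    project_name="my-project",                  # random UUID if omitted
    deep_tracing=False,                         # now disabled by default
    trace_across_async_contexts=False,          # opt in to cross-context tracing
    span_batch_size=50,
    span_flush_interval=1.0,
    span_num_workers=10,
)
```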
@@ -1608,9 +1648,11 @@ class Tracer:
         self.traces: List[Trace] = []
         self.enable_monitoring: bool = enable_monitoring
         self.enable_evaluations: bool = enable_evaluations
-        self.class_identifiers: Dict[
-
-
+        self.class_identifiers: Dict[
+            str, str
+        ] = {}  # Dictionary to store class identifiers
+        self.span_id_to_previous_span_id: Dict[str, str | None] = {}
+        self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
         self.current_span_id: Optional[str] = None
         self.current_trace: Optional[TraceClient] = None
         self.trace_across_async_contexts: bool = trace_across_async_contexts
@@ -1620,15 +1662,21 @@ class Tracer:
         self.use_s3 = use_s3
         if use_s3:
             from judgeval.common.s3_storage import S3Storage
-
-
-
-
-
-
+
+            try:
+                self.s3_storage = S3Storage(
+                    bucket_name=s3_bucket_name,
+                    aws_access_key_id=s3_aws_access_key_id,
+                    aws_secret_access_key=s3_aws_secret_access_key,
+                    region_name=s3_region_name,
+                )
+            except Exception as e:
+                print(f"Issue with initializing S3 storage, disabling S3: {e}")
+                self.use_s3 = False
+
         self.offline_mode: bool = offline_mode
         self.deep_tracing: bool = deep_tracing  # NEW: Store deep tracing setting
-
+
         # Initialize background span service
         self.enable_background_spans: bool = enable_background_spans
         self.background_span_service: Optional[BackgroundSpanService] = None
@@ -1638,49 +1686,61 @@ class Tracer:
                 organization_id=organization_id,
                 batch_size=span_batch_size,
                 flush_interval=span_flush_interval,
-                num_workers=span_num_workers
+                num_workers=span_num_workers,
             )
 
-    def set_current_span(self, span_id: str):
-        self.span_id_to_previous_span_id[span_id] =
+    def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
+        self.span_id_to_previous_span_id[span_id] = self.current_span_id
         self.current_span_id = span_id
         Tracer.current_span_id = span_id
         try:
             token = current_span_var.set(span_id)
-        except:
+        except Exception:
             token = None
         return token
-
+
     def get_current_span(self) -> Optional[str]:
         try:
             current_span_var_val = current_span_var.get()
-        except:
+        except Exception:
             current_span_var_val = None
-        return (
-
-
-
-
+        return (
+            (self.current_span_id or current_span_var_val)
+            if self.trace_across_async_contexts
+            else current_span_var_val
+        )
+
+    def reset_current_span(
+        self,
+        token: Optional[contextvars.Token[str | None]] = None,
+        span_id: Optional[str] = None,
+    ):
         try:
-
-
+            if token:
+                current_span_var.reset(token)
+        except Exception:
             pass
-
-
-
-
+        if not span_id:
+            span_id = self.current_span_id
+        if span_id:
+            self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
+            Tracer.current_span_id = self.current_span_id
+
+    def set_current_trace(
+        self, trace: TraceClient
+    ) -> Optional[contextvars.Token[TraceClient | None]]:
         """
         Set the current trace context in contextvars
         """
-        self.trace_id_to_previous_trace[trace.trace_id] =
+        self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
         self.current_trace = trace
         Tracer.current_trace = trace
         try:
             token = current_trace_var.set(trace)
-        except:
+        except Exception:
             token = None
         return token
-
+
     def get_current_trace(self) -> Optional[TraceClient]:
         """
         Get the current trace context.
@@ -1691,72 +1751,69 @@ class Tracer:
         """
         try:
             current_trace_var_val = current_trace_var.get()
-        except:
+        except Exception:
             current_trace_var_val = None
-
-
-
-
-
-
-
-
-
-
-
-
-        # If neither is available, return None
-        return None
-
-    def reset_current_trace(self, token: Optional[str] = None, trace_id: Optional[str] = None):
-        if not trace_id and self.current_trace:
-            trace_id = self.current_trace.trace_id
+        return (
+            (self.current_trace or current_trace_var_val)
+            if self.trace_across_async_contexts
+            else current_trace_var_val
+        )
+
+    def reset_current_trace(
+        self,
+        token: Optional[contextvars.Token[TraceClient | None]] = None,
+        trace_id: Optional[str] = None,
+    ):
         try:
-
-
+            if token:
+                current_trace_var.reset(token)
+        except Exception:
             pass
-
-
+        if not trace_id and self.current_trace:
+            trace_id = self.current_trace.trace_id
+        if trace_id:
+            self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
+            Tracer.current_trace = self.current_trace
 
     @contextmanager
     def trace(
-        self,
-        name: str,
-        project_name: str = None,
+        self,
+        name: str,
+        project_name: str | None = None,
         overwrite: bool = False,
-        rules: Optional[List[Rule]] = None  # Added rules parameter
+        rules: Optional[List[Rule]] = None,  # Added rules parameter
     ) -> Generator[TraceClient, None, None]:
         """Start a new trace context using a context manager"""
         trace_id = str(uuid.uuid4())
         project = project_name if project_name is not None else self.project_name
-
+
         # Get parent trace info from context
         parent_trace = self.get_current_trace()
         parent_trace_id = None
         parent_name = None
-
+
         if parent_trace:
             parent_trace_id = parent_trace.trace_id
             parent_name = parent_trace.name
 
         trace = TraceClient(
-            self,
-            trace_id,
-            name,
-            project_name=project,
+            self,
+            trace_id,
+            name,
+            project_name=project,
             overwrite=overwrite,
             rules=self.rules,  # Pass combined rules to the trace client
             enable_monitoring=self.enable_monitoring,
             enable_evaluations=self.enable_evaluations,
             parent_trace_id=parent_trace_id,
-            parent_name=parent_name
+            parent_name=parent_name,
         )
-
+
         # Set the current trace in context variables
         token = self.set_current_trace(trace)
-
+
         # Automatically create top-level span
-        with trace.span(name or "unnamed_trace")
+        with trace.span(name or "unnamed_trace"):
             try:
                 # Save the trace to the database to handle Evaluations' trace_id referential integrity
                 yield trace
@@ -1764,101 +1821,110 @@ class Tracer:
             # Reset the context variable
             self.reset_current_trace(token)
 
-
     def log(self, msg: str, label: str = "log", score: int = 1):
         """Log a message with the current span context"""
         current_span_id = self.get_current_span()
         current_trace = self.get_current_trace()
-        if current_span_id:
+        if current_span_id and current_trace:
             annotation = TraceAnnotation(
-                span_id=current_span_id,
-                text=msg,
-                label=label,
-                score=score
+                span_id=current_span_id, text=msg, label=label, score=score
             )
-
             current_trace.add_annotation(annotation)
 
         rprint(f"[bold]{label}:[/bold] {msg}")
-
-    def identify(
+
+    def identify(
+        self,
+        identifier: str,
+        track_state: bool = False,
+        track_attributes: Optional[List[str]] = None,
+        field_mappings: Optional[Dict[str, str]] = None,
+    ):
         """
         Class decorator that associates a class with a custom identifier and enables state tracking.
-
+
         This decorator creates a mapping between the class name and the provided
         identifier, which can be useful for tagging, grouping, or referencing
         classes in a standardized way. It also enables automatic state capture
         for instances of the decorated class when used with tracing.
-
+
         Args:
             identifier: The identifier to associate with the decorated class.
                 This will be used as the instance name in traces.
-            track_state: Whether to automatically capture the state (attributes)
+            track_state: Whether to automatically capture the state (attributes)
                 of instances before and after function execution. Defaults to False.
             track_attributes: Optional list of specific attribute names to track.
-                If None, all non-private attributes (not starting with '_')
+                If None, all non-private attributes (not starting with '_')
                 will be tracked when track_state=True.
-            field_mappings: Optional dictionary mapping internal attribute names to
+            field_mappings: Optional dictionary mapping internal attribute names to
                 display names in the captured state. For example:
-                {"system_prompt": "instructions"} will capture the
+                {"system_prompt": "instructions"} will capture the
                 'instructions' attribute as 'system_prompt' in the state.
-
+
         Example:
             @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
            class User:
                # Class implementation
        """
+
        def decorator(cls):
            class_name = cls.__name__
            self.class_identifiers[class_name] = {
                "identifier": identifier,
                "track_state": track_state,
                "track_attributes": track_attributes,
-                "field_mappings": field_mappings or {}
+                "field_mappings": field_mappings or {},
            }
            return cls
-
+
        return decorator
-
-    def _capture_instance_state(
+
+    def _capture_instance_state(
+        self, instance: Any, class_config: Dict[str, Any]
+    ) -> Dict[str, Any]:
        """
        Capture the state of an instance based on class configuration.
        Args:
            instance: The instance to capture the state of.
-            class_config: Configuration dictionary for state capture,
+            class_config: Configuration dictionary for state capture,
                expected to contain 'track_attributes' and 'field_mappings'.
        """
-        track_attributes = class_config.get(
-        field_mappings = class_config.get(
-
+        track_attributes = class_config.get("track_attributes")
+        field_mappings = class_config.get("field_mappings")
+
        if track_attributes:
-
            state = {attr: getattr(instance, attr, None) for attr in track_attributes}
        else:
-
-
+            state = {
+                k: v for k, v in instance.__dict__.items() if not k.startswith("_")
+            }
 
        if field_mappings:
-            state[
+            state["field_mappings"] = field_mappings
 
        return state
-
-
+
    def _get_instance_state_if_tracked(self, args):
        """
        Extract instance state if the instance should be tracked.
-
+
        Returns the captured state dict if tracking is enabled, None otherwise.
        """
-        if args and hasattr(args[0],
+        if args and hasattr(args[0], "__class__"):
            instance = args[0]
            class_name = instance.__class__.__name__
-            if (
-
-                self.class_identifiers[class_name]
-
-
-
+            if (
+                class_name in self.class_identifiers
+                and isinstance(self.class_identifiers[class_name], dict)
+                and self.class_identifiers[class_name].get("track_state", False)
+            ):
+                return self._capture_instance_state(
+                    instance, self.class_identifiers[class_name]
+                )
+
+    def _conditionally_capture_and_record_state(
+        self, trace_client_instance: TraceClient, args: tuple, is_before: bool
+    ):
        """Captures instance state if tracked and records it via the trace_client."""
        state = self._get_instance_state_if_tracked(args)
        if state:
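The `identify` docstring already shows the decorator form; the sketch below spells out what the state-capture configuration means at runtime. It expands the docstring's own example, so the class and attribute names are illustrative:

```python
@tracer.identify(
    identifier="user_model",
    track_state=True,                  # snapshot attributes before/after traced calls
    track_attributes=["name", "age"],  # restrict the snapshot to these fields
    field_mappings={"system_prompt": "instructions"},
)
class User:
    def __init__(self, name: str, age: int, instructions: str):
        self.name = name                  # tracked
        self.age = age                    # tracked
        self.instructions = instructions  # surfaced via field_mappings
```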
@@ -1866,11 +1932,20 @@ class Tracer:
                 trace_client_instance.record_state_before(state)
             else:
                 trace_client_instance.record_state_after(state)
-
-    def observe(
+
+    def observe(
+        self,
+        func=None,
+        *,
+        name=None,
+        span_type: SpanType = "span",
+        project_name: str | None = None,
+        overwrite: bool = False,
+        deep_tracing: bool | None = None,
+    ):
         """
         Decorator to trace function execution with detailed entry/exit information.
-
+
         Args:
             func: The function to decorate
             name: Optional custom name for the span (defaults to function name)
@@ -1881,56 +1956,71 @@ class Tracer:
                 If None, uses the tracer's default setting.
         """
         # If monitoring is disabled, return the function as is
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            if not self.enable_monitoring:
+                return func if func else lambda f: f
+
+            if func is None:
+                return lambda f: self.observe(
+                    f,
+                    name=name,
+                    span_type=span_type,
+                    project_name=project_name,
+                    overwrite=overwrite,
+                    deep_tracing=deep_tracing,
+                )
+
+            # Use provided name or fall back to function name
+            original_span_name = name or func.__name__
+
+            # Store custom attributes on the function object
+            func._judgment_span_name = original_span_name
+            func._judgment_span_type = span_type
+
+            # Use the provided deep_tracing value or fall back to the tracer's default
+            use_deep_tracing = (
+                deep_tracing if deep_tracing is not None else self.deep_tracing
+            )
+        except Exception:
+            return func
+
         if asyncio.iscoroutinefunction(func):
+
             @functools.wraps(func)
             async def async_wrapper(*args, **kwargs):
                 nonlocal original_span_name
                 class_name = None
-                instance_name = None
                 span_name = original_span_name
                 agent_name = None
 
-                if args and hasattr(args[0],
+                if args and hasattr(args[0], "__class__"):
                     class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(
+                    agent_name = get_instance_prefixed_name(
+                        args[0], class_name, self.class_identifiers
+                    )
 
                 # Get current trace from context
                 current_trace = self.get_current_trace()
-
+
                 # If there's no current trace, create a root trace
                 if not current_trace:
                     trace_id = str(uuid.uuid4())
-                    project =
-
+                    project = (
+                        project_name if project_name is not None else self.project_name
+                    )
+
                     # Create a new trace client to serve as the root
                     current_trace = TraceClient(
                         self,
                         trace_id,
-                        span_name,
+                        span_name,  # MODIFIED: Use span_name directly
                         project_name=project,
                         overwrite=overwrite,
                         rules=self.rules,
                         enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations
+                        enable_evaluations=self.enable_evaluations,
                     )
-
+
                     # Save empty trace and set trace context
                     # current_trace.save(empty_save=True, overwrite=overwrite)
                     trace_token = self.set_current_trace(current_trace)
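For reference while reading the wrapper hunks that follow, this is roughly how the reworked decorator is applied. A minimal sketch; the function bodies are placeholders:

```python
@tracer.observe(span_type="tool", deep_tracing=False)
def fetch_user(user_id: str) -> dict:
    # inputs are recorded via combine_args_kwargs, outputs via record_output
    return {"id": user_id}

@tracer.observe  # bare form also works: func arrives positionally
async def score_answer(answer: str) -> float:
    return 1.0
```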
@@ -1938,7 +2028,9 @@ class Tracer:
                     try:
                         # Use span for the function execution within the root trace
                         # This sets the current_span_var
-                        with current_trace.span(
+                        with current_trace.span(
+                            span_name, span_type=span_type
+                        ) as span:  # MODIFIED: Use span_name directly
                             # Record inputs
                             inputs = combine_args_kwargs(func, args, kwargs)
                             span.record_input(inputs)
@@ -1946,50 +2038,66 @@ class Tracer:
                                 span.record_agent_name(agent_name)
 
                             # Capture state before execution
-                            self._conditionally_capture_and_record_state(
-
-
-
-
-
-
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=True
+                            )
+
+                            try:
+                                if use_deep_tracing:
+                                    with _DeepTracer(self):
+                                        result = await func(*args, **kwargs)
+                                else:
                                     result = await func(*args, **kwargs)
-
-
-
-
-
-
-
+                            except Exception as e:
+                                _capture_exception_for_trace(
+                                    current_trace, sys.exc_info()
+                                )
+                                raise e
+
+                            # Capture state after execution
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=False
+                            )
+
                             # Record output
                             span.record_output(result)
                             return result
                     finally:
                         # Flush background spans before saving the trace
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        try:
+                            complete_trace_data = {
+                                "trace_id": current_trace.trace_id,
+                                "name": current_trace.name,
+                                "created_at": datetime.utcfromtimestamp(
+                                    current_trace.start_time
+                                ).isoformat(),
+                                "duration": current_trace.get_duration(),
+                                "trace_spans": [
+                                    span.model_dump()
+                                    for span in current_trace.trace_spans
+                                ],
+                                "overwrite": overwrite,
+                                "offline_mode": self.offline_mode,
+                                "parent_trace_id": current_trace.parent_trace_id,
+                                "parent_name": current_trace.parent_name,
+                            }
+                            # Save the completed trace
+                            trace_id, server_response = current_trace.save(
+                                overwrite=overwrite, final_save=True
+                            )
+
+                            # Store the complete trace data instead of just server response
+
+                            self.traces.append(complete_trace_data)
+
+                            # if self.background_span_service:
+                            #     self.background_span_service.flush()
+
+                            # Reset trace context (span context resets automatically)
+                            self.reset_current_trace(trace_token)
+                        except Exception as e:
+                            warnings.warn(f"Issue with async_wrapper: {e}")
+                            return
                 else:
                     with current_trace.span(span_name, span_type=span_type) as span:
                         inputs = combine_args_kwargs(func, args, kwargs)
@@ -1998,24 +2106,28 @@ class Tracer:
                             span.record_agent_name(agent_name)
 
                         # Capture state before execution
-                        self._conditionally_capture_and_record_state(
-
-
-
-
-
-
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=True
+                        )
+
+                        try:
+                            if use_deep_tracing:
+                                with _DeepTracer(self):
+                                    result = await func(*args, **kwargs)
+                            else:
                                 result = await func(*args, **kwargs)
-
-
-
-
-                        # Capture state after execution
-                        self._conditionally_capture_and_record_state(
-
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e
+
+                        # Capture state after execution
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=False
+                        )
+
                         span.record_output(result)
                         return result
-
+
             return async_wrapper
         else:
             # Non-async function implementation with deep tracing
@@ -2023,118 +2135,146 @@ class Tracer:
             def wrapper(*args, **kwargs):
                 nonlocal original_span_name
                 class_name = None
-                instance_name = None
                 span_name = original_span_name
                 agent_name = None
-                if args and hasattr(args[0],
+                if args and hasattr(args[0], "__class__"):
                     class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(
+                    agent_name = get_instance_prefixed_name(
+                        args[0], class_name, self.class_identifiers
+                    )
                 # Get current trace from context
                 current_trace = self.get_current_trace()
 
                 # If there's no current trace, create a root trace
                 if not current_trace:
                     trace_id = str(uuid.uuid4())
-                    project =
-
+                    project = (
+                        project_name if project_name is not None else self.project_name
+                    )
+
                     # Create a new trace client to serve as the root
                     current_trace = TraceClient(
                         self,
                         trace_id,
-                        span_name,
+                        span_name,  # MODIFIED: Use span_name directly
                         project_name=project,
                         overwrite=overwrite,
                         rules=self.rules,
                         enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations
+                        enable_evaluations=self.enable_evaluations,
                     )
-
+
                     # Save empty trace and set trace context
                     # current_trace.save(empty_save=True, overwrite=overwrite)
                     trace_token = self.set_current_trace(current_trace)
-
+
                     try:
                         # Use span for the function execution within the root trace
                         # This sets the current_span_var
-                        with current_trace.span(
+                        with current_trace.span(
+                            span_name, span_type=span_type
+                        ) as span:  # MODIFIED: Use span_name directly
                             # Record inputs
                             inputs = combine_args_kwargs(func, args, kwargs)
                             span.record_input(inputs)
                             if agent_name:
                                 span.record_agent_name(agent_name)
                             # Capture state before execution
-                            self._conditionally_capture_and_record_state(
-
-
-
-
-
-
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=True
+                            )
+
+                            try:
+                                if use_deep_tracing:
+                                    with _DeepTracer(self):
+                                        result = func(*args, **kwargs)
+                                else:
                                     result = func(*args, **kwargs)
-
-
-
-
+                            except Exception as e:
+                                _capture_exception_for_trace(
+                                    current_trace, sys.exc_info()
+                                )
+                                raise e
+
                             # Capture state after execution
-                            self._conditionally_capture_and_record_state(
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=False
+                            )
 
-
                             # Record output
                             span.record_output(result)
                             return result
                     finally:
                         # Flush background spans before saving the trace
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        try:
+                            # Save the completed trace
+                            trace_id, server_response = current_trace.save(
+                                overwrite=overwrite, final_save=True
+                            )
+
+                            # Store the complete trace data instead of just server response
+                            complete_trace_data = {
+                                "trace_id": current_trace.trace_id,
+                                "name": current_trace.name,
+                                "created_at": datetime.utcfromtimestamp(
+                                    current_trace.start_time
+                                ).isoformat(),
+                                "duration": current_trace.get_duration(),
+                                "trace_spans": [
+                                    span.model_dump()
+                                    for span in current_trace.trace_spans
+                                ],
+                                "overwrite": overwrite,
+                                "offline_mode": self.offline_mode,
+                                "parent_trace_id": current_trace.parent_trace_id,
+                                "parent_name": current_trace.parent_name,
+                            }
+                            self.traces.append(complete_trace_data)
+                            # Reset trace context (span context resets automatically)
+                            self.reset_current_trace(trace_token)
+                        except Exception as e:
+                            warnings.warn(f"Issue with save: {e}")
+                            return
                 else:
                     with current_trace.span(span_name, span_type=span_type) as span:
-
                         inputs = combine_args_kwargs(func, args, kwargs)
                         span.record_input(inputs)
                         if agent_name:
                             span.record_agent_name(agent_name)
 
                         # Capture state before execution
-                        self._conditionally_capture_and_record_state(
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=True
+                        )
 
-
-
-
-
+                        try:
+                            if use_deep_tracing:
+                                with _DeepTracer(self):
+                                    result = func(*args, **kwargs)
+                            else:
                                 result = func(*args, **kwargs)
-
-
-
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e
+
                         # Capture state after execution
-                        self._conditionally_capture_and_record_state(
-
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=False
+                        )
+
                         span.record_output(result)
                         return result
-
+
             return wrapper
-
-    def observe_tools(
-
+
+    def observe_tools(
+        self,
+        cls=None,
+        *,
+        exclude_methods: Optional[List[str]] = None,
+        include_private: bool = False,
+        warn_on_double_decoration: bool = True,
+    ):
         """
         Automatically adds @observe(span_type="tool") to all methods in a class.
 
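A usage sketch for the new keyword-only options on `observe_tools`; class and method names are placeholders:

```python
@tracer.observe_tools(exclude_methods=["helper"], include_private=False)
class SearchTools:
    def web_search(self, query: str) -> list:  # wrapped as span_type="tool"
        return []

    def helper(self) -> None:                  # skipped via exclude_methods
        pass
```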
@@ -2146,28 +2286,32 @@ class Tracer:
         """
 
         if exclude_methods is None:
-            exclude_methods = [
-
+            exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]
+
         def decorate_class(cls):
             if not self.enable_monitoring:
                 return cls
-
+
             decorated = []
             skipped = []
-
+
             for name in dir(cls):
                 method = getattr(cls, name)
-
-                if (
-
-
-
+
+                if (
+                    not callable(method)
+                    or name in exclude_methods
+                    or (name.startswith("_") and not include_private)
+                    or not hasattr(cls, name)
+                ):
                     continue
-
-                if hasattr(method,
+
+                if hasattr(method, "_judgment_span_name"):
                     skipped.append(name)
                     if warn_on_double_decoration:
-                        print(
+                        print(
+                            f"Warning: {cls.__name__}.{name} already decorated, skipping"
+                        )
                     continue
 
                 try:
@@ -2177,28 +2321,76 @@ class Tracer:
             except Exception as e:
                 if warn_on_double_decoration:
                     print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")
-
+
             return cls
-
+
         return decorate_class if cls is None else decorate_class(cls)
 
     def async_evaluate(self, *args, **kwargs):
-
-
+        try:
+            if not self.enable_monitoring or not self.enable_evaluations:
+                return
 
-
-
+            # --- Get trace_id passed explicitly (if any) ---
+            passed_trace_id = kwargs.pop(
+                "trace_id", None
+            )  # Get and remove trace_id from kwargs
 
+            current_trace = self.get_current_trace()
+
+            if current_trace:
+                # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
+                # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
+                if passed_trace_id:
+                    kwargs["trace_id"] = (
+                        passed_trace_id  # Re-add if needed by TraceClient.async_evaluate
+                    )
+                current_trace.async_evaluate(*args, **kwargs)
+            else:
+                warnings.warn(
+                    "No trace found (context var or fallback), skipping evaluation"
+                )  # Modified warning
+        except Exception as e:
+            warnings.warn(f"Issue with async_evaluate: {e}")
+
+    def update_metadata(self, metadata: dict):
+        """
+        Update metadata for the current trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+        """
         current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.update_metadata(metadata)
+        else:
+            warnings.warn("No current trace found, cannot set metadata")
+
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for the current trace.
 
+        Args:
+            customer_id: The customer ID to set
+        """
+        current_trace = self.get_current_trace()
         if current_trace:
-
-            # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
-            if passed_trace_id:
-                kwargs['trace_id'] = passed_trace_id # Re-add if needed by TraceClient.async_evaluate
-            current_trace.async_evaluate(*args, **kwargs)
+            current_trace.set_customer_id(customer_id)
         else:
-            warnings.warn("No trace found
+            warnings.warn("No current trace found, cannot set customer ID")
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for the current trace.
+
+        Args:
+            tags: List of tags to set
+        """
+        current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.set_tags(tags)
+        else:
+            warnings.warn("No current trace found, cannot set tags")
 
     def get_background_span_service(self) -> Optional[BackgroundSpanService]:
         """Get the background span service instance."""
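The new `update_metadata`, `set_customer_id`, and `set_tags` helpers all follow the same shape: look up the current trace and delegate, warning if none is active. A usage sketch with placeholder values:

```python
with tracer.trace("support-chat"):
    tracer.update_metadata({"channel": "email"})
    tracer.set_customer_id("cust_0042")
    tracer.set_tags(["beta", "priority"])
# Outside a trace, each call emits a warning instead of raising.
```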
@@ -2215,31 +2407,43 @@ class Tracer:
             self.background_span_service.shutdown()
             self.background_span_service = None
 
-
+
+def _get_current_trace(
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
+):
+    if trace_across_async_contexts:
+        return Tracer.current_trace
+    else:
+        return current_trace_var.get()
+
+
+def wrap(
+    client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
+) -> Any:
     """
     Wraps an API client to add tracing capabilities.
     Supports OpenAI, Together, Anthropic, and Google GenAI clients.
     Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
-
+    (
+        span_name,
+        original_create,
+        original_responses_create,
+        original_stream,
+        original_beta_parse,
+    ) = _get_client_config(client)
 
-    def _get_current_trace():
-        if trace_across_async_contexts:
-            return Tracer.current_trace
-        else:
-            return current_trace_var.get()
-
     def _record_input_and_check_streaming(span, kwargs, is_responses=False):
         """Record input and check for streaming"""
         is_streaming = kwargs.get("stream", False)
 
-
+        # Record input based on whether this is a responses endpoint
        if is_responses:
            span.record_input(kwargs)
        else:
            input_data = _format_input_data(client, **kwargs)
            span.record_input(input_data)
-
+
        # Warn about token counting limitations with streaming
        if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
            if not kwargs.get("stream_options", {}).get("include_usage"):
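`wrap()` returns the same client object with its create/stream entry points patched, so call sites stay unchanged. A sketch for an OpenAI client, assuming an already-constructed `tracer`; the model and message values are placeholders:

```python
from openai import OpenAI
from judgeval.common.tracer import wrap

client = wrap(OpenAI())  # same client object, with .create patched for tracing

with tracer.trace("qa"):
    resp = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "hello"}],
    )
```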
@@ -2247,88 +2451,101 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
                     "OpenAI streaming calls don't include token counts by default. "
                     "To enable token counting with streams, set stream_options={'include_usage': True} "
                     "in your API call arguments.",
-                    UserWarning
+                    UserWarning,
                 )
-
+
         return is_streaming
-
+
     def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
         """Format and record the output in the span"""
         if is_streaming:
             output_entry = span.record_output("<pending stream>")
             wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
-            return wrapper_func(response, client, output_entry)
+            return wrapper_func(
+                response, client, output_entry, trace_across_async_contexts
+            )
         else:
-            format_func = _format_response_output_data if is_responses else _format_output_data
+            format_func = (
+                _format_response_output_data if is_responses else _format_output_data
+            )
             output, usage = format_func(client, response)
             span.record_output(output)
             span.record_usage(usage)
-
+
         # Queue the completed LLM span now that it has all data (input, output, usage)
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if current_trace and current_trace.background_span_service:
             # Get the current span from the trace client
             current_span_id = current_trace.get_current_span()
             if current_span_id and current_span_id in current_trace.span_id_to_span:
                 completed_span = current_trace.span_id_to_span[current_span_id]
-                current_trace.background_span_service.queue_span(
-
+                current_trace.background_span_service.queue_span(
+                    completed_span, span_state="completed"
+                )
+
         return response
-
+
     # --- Traced Async Functions ---
     async def traced_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = await original_create(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     async def traced_beta_parse_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_beta_parse(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = await original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
-
+
     # Async responses for OpenAI clients
     async def traced_response_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_responses_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
             try:
                 response_or_iterator = await original_responses_create(*args, **kwargs)
-                return _format_and_record_output(
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, True
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     # Function replacing .stream() for async clients
     def traced_stream_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace or not original_stream:
             return original_stream(*args, **kwargs)
-
+
         original_manager = original_stream(*args, **kwargs)
         return _TracedAsyncStreamManagerWrapper(
             original_manager=original_manager,
|
|
2336
2553
|
span_name=span_name,
|
2337
2554
|
trace_client=current_trace,
|
2338
2555
|
stream_wrapper_func=_async_stream_wrapper,
|
2339
|
-
input_kwargs=kwargs
|
2556
|
+
input_kwargs=kwargs,
|
2557
|
+
trace_across_async_contexts=trace_across_async_contexts,
|
2340
2558
|
)
|
2341
|
-
|
2559
|
+
|
2342
2560
|
# --- Traced Sync Functions ---
|
2343
2561
|
def traced_create_sync(*args, **kwargs):
|
2344
|
-
current_trace = _get_current_trace()
|
2562
|
+
current_trace = _get_current_trace(trace_across_async_contexts)
|
2345
2563
|
if not current_trace:
|
2346
2564
|
return original_create(*args, **kwargs)
|
2347
|
-
|
2565
|
+
|
2348
2566
|
with current_trace.span(span_name, span_type="llm") as span:
|
2349
2567
|
is_streaming = _record_input_and_check_streaming(span, kwargs)
|
2350
|
-
|
2568
|
+
|
2351
2569
|
try:
|
2352
2570
|
response_or_iterator = original_create(*args, **kwargs)
|
2353
|
-
return _format_and_record_output(
|
2571
|
+
return _format_and_record_output(
|
2572
|
+
span, response_or_iterator, is_streaming, False, False
|
2573
|
+
)
|
2354
2574
|
except Exception as e:
|
2355
2575
|
_capture_exception_for_trace(span, sys.exc_info())
|
2356
2576
|
raise e
|
2357
|
-
|
2577
|
+
|
2358
2578
|
def traced_beta_parse_sync(*args, **kwargs):
|
2359
|
-
current_trace = _get_current_trace()
|
2579
|
+
current_trace = _get_current_trace(trace_across_async_contexts)
|
2360
2580
|
if not current_trace:
|
2361
2581
|
return original_beta_parse(*args, **kwargs)
|
2362
|
-
|
2582
|
+
|
2363
2583
|
with current_trace.span(span_name, span_type="llm") as span:
|
2364
2584
|
is_streaming = _record_input_and_check_streaming(span, kwargs)
|
2365
|
-
|
2585
|
+
|
2366
2586
|
try:
|
2367
2587
|
response_or_iterator = original_beta_parse(*args, **kwargs)
|
2368
|
-
return _format_and_record_output(
|
2588
|
+
return _format_and_record_output(
|
2589
|
+
span, response_or_iterator, is_streaming, False, False
|
2590
|
+
)
|
2369
2591
|
except Exception as e:
|
2370
2592
|
_capture_exception_for_trace(span, sys.exc_info())
|
2371
2593
|
raise e
|
2372
|
-
|
2594
|
+
|
2373
2595
|
def traced_response_create_sync(*args, **kwargs):
|
2374
|
-
current_trace = _get_current_trace()
|
2596
|
+
current_trace = _get_current_trace(trace_across_async_contexts)
|
2375
2597
|
if not current_trace:
|
2376
2598
|
return original_responses_create(*args, **kwargs)
|
2377
|
-
|
2599
|
+
|
2378
2600
|
with current_trace.span(span_name, span_type="llm") as span:
|
2379
|
-
is_streaming = _record_input_and_check_streaming(
|
2380
|
-
|
2601
|
+
is_streaming = _record_input_and_check_streaming(
|
2602
|
+
span, kwargs, is_responses=True
|
2603
|
+
)
|
2604
|
+
|
2381
2605
|
try:
|
2382
2606
|
response_or_iterator = original_responses_create(*args, **kwargs)
|
2383
|
-
return _format_and_record_output(
|
2607
|
+
return _format_and_record_output(
|
2608
|
+
span, response_or_iterator, is_streaming, False, True
|
2609
|
+
)
|
2384
2610
|
except Exception as e:
|
2385
2611
|
_capture_exception_for_trace(span, sys.exc_info())
|
2386
2612
|
raise e
|
2387
|
-
|
2613
|
+
|
2388
2614
|
# Function replacing sync .stream()
|
2389
2615
|
def traced_stream_sync(*args, **kwargs):
|
2390
|
-
current_trace = _get_current_trace()
|
2616
|
+
current_trace = _get_current_trace(trace_across_async_contexts)
|
2391
2617
|
if not current_trace or not original_stream:
|
2392
2618
|
return original_stream(*args, **kwargs)
|
2393
|
-
|
2619
|
+
|
2394
2620
|
original_manager = original_stream(*args, **kwargs)
|
2395
2621
|
return _TracedSyncStreamManagerWrapper(
|
2396
2622
|
original_manager=original_manager,
|
@@ -2398,15 +2624,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
|
|
2398
2624
|
span_name=span_name,
|
2399
2625
|
trace_client=current_trace,
|
2400
2626
|
stream_wrapper_func=_sync_stream_wrapper,
|
2401
|
-
input_kwargs=kwargs
|
2627
|
+
input_kwargs=kwargs,
|
2628
|
+
trace_across_async_contexts=trace_across_async_contexts,
|
2402
2629
|
)
|
2403
|
-
|
2630
|
+
|
2404
2631
|
# --- Assign Traced Methods to Client Instance ---
|
2405
2632
|
if isinstance(client, (AsyncOpenAI, AsyncTogether)):
|
2406
2633
|
client.chat.completions.create = traced_create_async
|
2407
2634
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
2408
2635
|
client.responses.create = traced_response_create_async
|
2409
|
-
if
|
2636
|
+
if (
|
2637
|
+
hasattr(client, "beta")
|
2638
|
+
and hasattr(client.beta, "chat")
|
2639
|
+
and hasattr(client.beta.chat, "completions")
|
2640
|
+
and hasattr(client.beta.chat.completions, "parse")
|
2641
|
+
):
|
2410
2642
|
client.beta.chat.completions.parse = traced_beta_parse_async
|
2411
2643
|
elif isinstance(client, AsyncAnthropic):
|
2412
2644
|
client.messages.create = traced_create_async
|
@@ -2418,7 +2650,12 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
|
|
2418
2650
|
client.chat.completions.create = traced_create_sync
|
2419
2651
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
2420
2652
|
client.responses.create = traced_response_create_sync
|
2421
|
-
if
|
2653
|
+
if (
|
2654
|
+
hasattr(client, "beta")
|
2655
|
+
and hasattr(client.beta, "chat")
|
2656
|
+
and hasattr(client.beta.chat, "completions")
|
2657
|
+
and hasattr(client.beta.chat.completions, "parse")
|
2658
|
+
):
|
2422
2659
|
client.beta.chat.completions.parse = traced_beta_parse_sync
|
2423
2660
|
elif isinstance(client, Anthropic):
|
2424
2661
|
client.messages.create = traced_create_sync
|
@@ -2426,17 +2663,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
         client.messages.stream = traced_stream_sync
     elif isinstance(client, genai.Client):
         client.models.generate_content = traced_create_sync
-
+
     return client
 
+
 # Helper functions for client-specific operations
 
-def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable], Optional[callable], Optional[callable]]:
+
+def _get_client_config(
+    client: ApiClient,
+) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
     """Returns configuration tuple for the given API client.
-
+
     Args:
         client: An instance of OpenAI, Together, or Anthropic client
-
+
     Returns:
         tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
         - span_name: String identifier for tracing
@@ -2444,23 +2685,36 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
         - responses_method: Reference to the client's responses method (if applicable)
         - stream_method: Reference to the client's stream method (if applicable)
         - beta_parse_method: Reference to the client's beta parse method (if applicable)
-
+
     Raises:
         ValueError: If client type is not supported
     """
     if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None, client.beta.chat.completions.parse
+        return (
+            "OPENAI_API_CALL",
+            client.chat.completions.create,
+            client.responses.create,
+            None,
+            client.beta.chat.completions.parse,
+        )
     elif isinstance(client, (Together, AsyncTogether)):
         return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream, None
+        return (
+            "ANTHROPIC_API_CALL",
+            client.messages.create,
+            None,
+            client.messages.stream,
+            None,
+        )
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
         return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
     raise ValueError(f"Unsupported client type: {type(client)}")
 
+
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
     """Format input parameters based on client type.
-
+
     Extracts relevant parameters from kwargs based on the client type
     to ensure consistent tracing across different APIs.
     """
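Per the docstring above, every provider fills the same five-slot tuple, with `None` in the slots it does not support. A sketch of that contract (assumes the `openai` package is importable and the module-level names from this file are in scope):

```python
# Sketch of the _get_client_config() contract: one tuple shape, provider-specific
# gaps filled with None.
from openai import OpenAI

span_name, create, responses_create, stream, beta_parse = _get_client_config(OpenAI())
assert span_name == "OPENAI_API_CALL"
assert stream is None  # only Anthropic clients populate the stream slot
```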
@@ -2473,25 +2727,23 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
             input_data["response_format"] = kwargs.get("response_format")
         return input_data
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return {
-            "model": kwargs.get("model"),
-            "contents": kwargs.get("contents")
-        }
+        return {"model": kwargs.get("model"), "contents": kwargs.get("contents")}
     # Anthropic requires additional max_tokens parameter
     return {
         "model": kwargs.get("model"),
         "messages": kwargs.get("messages"),
-        "max_tokens": kwargs.get("max_tokens")
+        "max_tokens": kwargs.get("max_tokens"),
     }
 
-def _format_response_output_data(client: ApiClient, response: Any) -> dict:
+
+def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
     """Format API response data based on client type.
-
+
     Normalizes different response formats into a consistent structure
     for tracing purposes.
     """
     message_content = None
-    prompt_tokens = 0
+    prompt_tokens = 0
     completion_tokens = 0
     model_name = None
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
@@ -2501,14 +2753,16 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
         message_content = response.output
     else:
         warnings.warn(f"Unsupported client type: {type(client)}")
-        return
-
-    prompt_cost, completion_cost = cost_per_token(
+        return None, None
+
+    prompt_cost, completion_cost = cost_per_token(
         model=model_name,
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
     )
-    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
     usage = TraceUsage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -2516,17 +2770,19 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
         prompt_tokens_cost_usd=prompt_cost,
         completion_tokens_cost_usd=completion_cost,
         total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
     )
     return message_content, usage
 
 
-def _format_output_data(client: ApiClient, response: Any) -> dict:
+def _format_output_data(
+    client: ApiClient, response: Any
+) -> tuple[Optional[str], Optional[TraceUsage]]:
     """Format API response data based on client type.
-
+
     Normalizes different response formats into a consistent structure
     for tracing purposes.
-
+
     Returns:
         dict containing:
         - content: The generated text
@@ -2541,7 +2797,10 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         model_name = response.model
         prompt_tokens = response.usage.prompt_tokens
         completion_tokens = response.usage.completion_tokens
-        if hasattr(response.choices[0].message, "parsed") and response.choices[0].message.parsed:
+        if (
+            hasattr(response.choices[0].message, "parsed")
+            and response.choices[0].message.parsed
+        ):
             message_content = response.choices[0].message.parsed
         else:
             message_content = response.choices[0].message.content
@@ -2558,13 +2817,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
     else:
         warnings.warn(f"Unsupported client type: {type(client)}")
         return None, None
-
+
     prompt_cost, completion_cost = cost_per_token(
         model=model_name,
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
     )
-    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
     usage = TraceUsage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -2572,56 +2833,61 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         prompt_tokens_cost_usd=prompt_cost,
         completion_tokens_cost_usd=completion_cost,
         total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
     )
     return message_content, usage
 
+
 def combine_args_kwargs(func, args, kwargs):
     """
     Combine positional arguments and keyword arguments into a single dictionary.
-
+
     Args:
         func: The function being called
         args: Tuple of positional arguments
         kwargs: Dictionary of keyword arguments
-
+
     Returns:
         A dictionary combining both args and kwargs
     """
     try:
         import inspect
+
         sig = inspect.signature(func)
         param_names = list(sig.parameters.keys())
-
+
         args_dict = {}
         for i, arg in enumerate(args):
             if i < len(param_names):
                 args_dict[param_names[i]] = arg
             else:
                 args_dict[f"arg{i}"] = arg
-
+
         return {**args_dict, **kwargs}
-    except Exception as e:
+    except Exception:
         # Fallback if signature inspection fails
         return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
 
+
 # NOTE: This builds once, can be tweaked if we are missing / capturing other unncessary modules
 # @link https://docs.python.org/3.13/library/sysconfig.html
 _TRACE_FILEPATH_BLOCKLIST = tuple(
     os.path.realpath(p) + os.sep
     for p in {
-        sysconfig.get_paths()['stdlib'],
-        sysconfig.get_paths().get('platstdlib', ''),
+        sysconfig.get_paths()["stdlib"],
+        sysconfig.get_paths().get("platstdlib", ""),
         *site.getsitepackages(),
         site.getusersitepackages(),
         *(
-            [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
-            if os.environ.get('JUDGMENT_DEV')
+            [os.path.join(os.path.dirname(__file__), "../../judgeval/")]
+            if os.environ.get("JUDGMENT_DEV")
             else []
         ),
-        }
+    }
+    if p
 )
 
+
 # Add the new TraceThreadPoolExecutor class
 class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     """
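The `combine_args_kwargs` helper in the hunk above maps positionals onto parameter names via `inspect.signature`, with `arg{i}` keys as the overflow and fallback. A behavior sketch (the `create` stub is hypothetical):

```python
# Behavior sketch for combine_args_kwargs; `create` is a hypothetical stub.
def create(model, messages, stream=False):
    pass

combined = combine_args_kwargs(create, ("gpt-4.1-mini",), {"messages": [], "stream": True})
assert combined == {"model": "gpt-4.1-mini", "messages": [], "stream": True}

# Extra positionals beyond the signature get positional arg{i} keys.
combined = combine_args_kwargs(create, ("gpt-4.1-mini", [], True, "extra"), {})
assert combined["arg3"] == "extra"
```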
@@ -2633,6 +2899,7 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     allowing the Tracer to maintain correct parent-child relationships across
     thread boundaries.
     """
+
     def submit(self, fn, /, *args, **kwargs):
         """
         Submit a callable to be executed with the captured context.
@@ -2653,9 +2920,11 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     # Note: The `map` method would also need to be overridden for full context
     # propagation if users rely on it, but `submit` is the most common use case.
 
+
 # Helper functions for stream processing
 # ---------------------------------------
 
+
 def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
     """Extracts the text content from a stream chunk based on the client type."""
     try:
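The `TraceThreadPoolExecutor` described above relies on the standard `contextvars` snapshot-and-run pattern, which is what lets values like `current_trace_var` survive the hop onto a worker thread. A minimal sketch of that pattern; it mirrors what the class docstring describes, not necessarily judgeval's exact implementation:

```python
# Minimal sketch of submit()-with-captured-context using only stdlib APIs.
import concurrent.futures
import contextvars

class ContextPropagatingExecutor(concurrent.futures.ThreadPoolExecutor):
    def submit(self, fn, /, *args, **kwargs):
        ctx = contextvars.copy_context()  # snapshot the caller's ContextVars
        # Run fn inside the snapshot on the worker thread, so ContextVar reads
        # (e.g. the current trace) resolve as they did at submission time.
        return super().submit(ctx.run, fn, *args, **kwargs)
```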
@@ -2667,34 +2936,49 @@ def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
             return chunk.delta.text
         elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
             # Google streams Candidate objects
-            if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
+            if (
+                chunk.candidates
+                and chunk.candidates[0].content
+                and chunk.candidates[0].content.parts
+            ):
                 return chunk.candidates[0].content.parts[0].text
     except (AttributeError, IndexError, KeyError):
         # Handle cases where chunk structure is unexpected or doesn't contain content
-        pass
+        pass  # Return None
     return None
 
-def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
+
+def _extract_usage_from_final_chunk(
+    client: ApiClient, chunk: Any
+) -> Optional[Dict[str, int]]:
     """Extracts usage data if present in the *final* chunk (client-specific)."""
     try:
         # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
         # This typically requires specific API versions or settings. Often usage is *not* streamed.
         if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
             # Check if usage is directly on the chunk (some models might do this)
-            if hasattr(chunk, 'usage') and chunk.usage:
+            if hasattr(chunk, "usage") and chunk.usage:
                 prompt_tokens = chunk.usage.prompt_tokens
                 completion_tokens = chunk.usage.completion_tokens
             # Check if usage is nested within choices (less common for final chunk, but check)
-            elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
+            elif (
+                chunk.choices
+                and hasattr(chunk.choices[0], "usage")
+                and chunk.choices[0].usage
+            ):
                 prompt_tokens = chunk.choices[0].usage.prompt_tokens
                 completion_tokens = chunk.choices[0].usage.completion_tokens
-
+
             prompt_cost, completion_cost = cost_per_token(
-                model=chunk.model,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-            )
-            total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+                model=chunk.model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
+            total_cost_usd = (
+                (prompt_cost + completion_cost)
+                if prompt_cost and completion_cost
+                else None
+            )
             return TraceUsage(
                 prompt_tokens=chunk.usage.prompt_tokens,
                 completion_tokens=chunk.usage.completion_tokens,
@@ -2702,9 +2986,9 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
                 prompt_tokens_cost_usd=prompt_cost,
                 completion_tokens_cost_usd=completion_cost,
                 total_cost_usd=total_cost_usd,
-                model_name=chunk.model
+                model_name=chunk.model,
             )
-
+        # Anthropic includes usage in the 'message_stop' event type
         elif isinstance(client, (Anthropic, AsyncAnthropic)):
             if chunk.type == "message_stop":
                 # Anthropic final usage is often attached to the *message* object, not the chunk directly
@@ -2713,18 +2997,18 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
                 # This is a placeholder - Anthropic usage typically needs a separate call or context.
                 pass
         elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-            # Google provides usage metadata on the full response object, not typically streamed per chunk.
-            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
-            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
-                return {
-                    'prompt_tokens': chunk.usage_metadata.prompt_token_count,
-                    'completion_tokens': chunk.usage_metadata.candidates_token_count,
-                    'total_tokens': chunk.usage_metadata.total_token_count,
-                }
+            # Google provides usage metadata on the full response object, not typically streamed per chunk.
+            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
+            if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+                return {
+                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
+                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
+                    "total_tokens": chunk.usage_metadata.total_token_count,
+                }
 
     except (AttributeError, IndexError, KeyError, TypeError):
         # Handle cases where usage data is missing or malformed
-        pass
+        pass  # Return None
     return None
 
 
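The Google branch in the hunk above returns a plain dict rather than a `TraceUsage`; the stream wrappers below detect that with `not isinstance(usage_info, TraceUsage)` and convert it, attaching costs via `cost_per_token`. The normalized shape, with keys taken from the hunk and illustrative numbers only:

```python
# Normalized usage dict produced for Google streams (keys from the hunk above).
usage_info = {
    "prompt_tokens": 12,
    "completion_tokens": 34,
    "total_tokens": 46,
}
```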
@@ -2732,7 +3016,8 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
 def _sync_stream_wrapper(
     original_stream: Iterator,
     client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> Generator[Any, None, None]:
     """Wraps a synchronous stream iterator to capture content and update the trace."""
     content_parts = []  # Use a list instead of string concatenation
@@ -2742,9 +3027,11 @@ def _sync_stream_wrapper(
         for chunk in original_stream:
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
-                content_parts.append(content_part) # Append to list instead of concatenating
-            last_chunk = chunk # Keep track of the last chunk for potential usage data
-            yield chunk # Pass the chunk to the caller
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
+            last_chunk = chunk  # Keep track of the last chunk for potential usage data
+            yield chunk  # Pass the chunk to the caller
     finally:
         # Attempt to extract usage from the last chunk received
         if last_chunk:
@@ -2753,23 +3040,25 @@ def _sync_stream_wrapper(
             # Update the trace entry with the accumulated content and usage
             span.output = "".join(content_parts)
             span.usage = final_usage
-
+
             # Queue the completed LLM span now that streaming is done and all data is available
-
-
-
-
-
-
-
+
+            current_trace = _get_current_trace(trace_across_async_contexts)
+            if current_trace and current_trace.background_span_service:
+                current_trace.background_span_service.queue_span(
+                    span, span_state="completed"
+                )
+
     # Note: We might need to adjust _serialize_output if this dict causes issues,
     # but Pydantic's model_dump should handle dicts.
 
+
 # --- Async Stream Wrapper ---
 async def _async_stream_wrapper(
     original_stream: AsyncIterator,
     client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> AsyncGenerator[Any, None]:
     # [Existing logic - unchanged]
     content_parts = []  # Use a list instead of string concatenation
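The sync wrapper above and the async wrapper that follows share one shape: re-yield every chunk so the caller's iteration is untouched, and defer span finalization to a `finally` block that runs once the consumer exhausts (or abandons) the stream. A stripped-down sketch, where `extract` and `finalize` are hypothetical stand-ins for the chunk parsing and span/usage updates:

```python
# Stripped-down shape of _sync_stream_wrapper; extract() and finalize() are
# hypothetical stand-ins, not names from this module.
def passthrough_stream(original_stream, extract, finalize):
    content_parts, last_chunk = [], None
    try:
        for chunk in original_stream:
            text = extract(chunk)
            if text:
                content_parts.append(text)
            last_chunk = chunk
            yield chunk  # the caller sees the stream unchanged
    finally:
        # Runs on exhaustion, break, or garbage collection of the generator.
        finalize("".join(content_parts), last_chunk)
```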
@@ -2778,56 +3067,70 @@ async def _async_stream_wrapper(
     anthropic_input_tokens = 0
     anthropic_output_tokens = 0
 
-    target_span_id = span.span_id
-
     try:
         model_name = ""
         async for chunk in original_stream:
             # Check for OpenAI's final usage chunk
-            if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
+            if (
+                isinstance(client, (AsyncOpenAI, OpenAI))
+                and hasattr(chunk, "usage")
+                and chunk.usage is not None
+            ):
                 final_usage_data = {
                     "prompt_tokens": chunk.usage.prompt_tokens,
                     "completion_tokens": chunk.usage.completion_tokens,
-                    "total_tokens": chunk.usage.total_tokens
+                    "total_tokens": chunk.usage.total_tokens,
                 }
                 model_name = chunk.model
                 yield chunk
                 continue
 
-            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
+            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
+                chunk, "type"
+            ):
                 if chunk.type == "message_start":
-                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
-                        anthropic_input_tokens = chunk.message.usage.input_tokens
-                        model_name = chunk.message.model
+                    if (
+                        hasattr(chunk, "message")
+                        and hasattr(chunk.message, "usage")
+                        and hasattr(chunk.message.usage, "input_tokens")
+                    ):
+                        anthropic_input_tokens = chunk.message.usage.input_tokens
+                        model_name = chunk.message.model
                 elif chunk.type == "message_delta":
-                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                    if hasattr(chunk, "usage") and hasattr(
+                        chunk.usage, "output_tokens"
+                    ):
                         anthropic_output_tokens = chunk.usage.output_tokens
 
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
-                content_parts.append(content_part) # Append to list instead of concatenating
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
                 last_content_chunk = chunk
 
             yield chunk
     finally:
         anthropic_final_usage = None
-        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
-                anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
-            anthropic_final_usage = {
-                'prompt_tokens': anthropic_input_tokens,
-                'completion_tokens': anthropic_output_tokens,
-                'total_tokens': anthropic_input_tokens + anthropic_output_tokens}
+        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
+            anthropic_input_tokens > 0 or anthropic_output_tokens > 0
+        ):
+            anthropic_final_usage = {
+                "prompt_tokens": anthropic_input_tokens,
+                "completion_tokens": anthropic_output_tokens,
+                "total_tokens": anthropic_input_tokens + anthropic_output_tokens,
+            }
 
         usage_info = None
         if final_usage_data:
-            usage_info = final_usage_data
+            usage_info = final_usage_data
         elif anthropic_final_usage:
-            usage_info = anthropic_final_usage
+            usage_info = anthropic_final_usage
         elif last_content_chunk:
             usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
 
         if usage_info and not isinstance(usage_info, TraceUsage):
-            prompt_cost, completion_cost = cost_per_token(
+            prompt_cost, completion_cost = cost_per_token(
                 model=model_name,
                 prompt_tokens=usage_info["prompt_tokens"],
                 completion_tokens=usage_info["completion_tokens"],
|
|
2839
3142
|
prompt_tokens_cost_usd=prompt_cost,
|
2840
3143
|
completion_tokens_cost_usd=completion_cost,
|
2841
3144
|
total_cost_usd=prompt_cost + completion_cost,
|
2842
|
-
model_name=model_name
|
3145
|
+
model_name=model_name,
|
2843
3146
|
)
|
2844
|
-
if span and hasattr(span,
|
2845
|
-
span.output =
|
3147
|
+
if span and hasattr(span, "output"):
|
3148
|
+
span.output = "".join(content_parts)
|
2846
3149
|
span.usage = usage_info
|
2847
|
-
start_ts = getattr(span,
|
3150
|
+
start_ts = getattr(span, "created_at", time.time())
|
2848
3151
|
span.duration = time.time() - start_ts
|
2849
|
-
|
3152
|
+
|
2850
3153
|
# Queue the completed LLM span now that async streaming is done and all data is available
|
2851
|
-
|
2852
|
-
|
2853
|
-
|
2854
|
-
|
3154
|
+
current_trace = _get_current_trace(trace_across_async_contexts)
|
3155
|
+
if current_trace and current_trace.background_span_service:
|
3156
|
+
current_trace.background_span_service.queue_span(
|
3157
|
+
span, span_state="completed"
|
3158
|
+
)
|
2855
3159
|
# else: # Handle error case if necessary, but remove debug print
|
2856
3160
|
|
3161
|
+
|
2857
3162
|
def cost_per_token(*args, **kwargs):
|
2858
3163
|
try:
|
2859
3164
|
return _original_cost_per_token(*args, **kwargs)
|
@@ -2861,8 +3166,18 @@ def cost_per_token(*args, **kwargs):
         warnings.warn(f"Error calculating cost per token: {e}")
         return None, None
 
+
 class _BaseStreamManagerWrapper:
-    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+    def __init__(
+        self,
+        original_manager,
+        client,
+        span_name,
+        trace_client,
+        stream_wrapper_func,
+        input_kwargs,
+        trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
+    ):
         self._original_manager = original_manager
         self._client = client
         self._span_name = span_name
@@ -2870,13 +3185,19 @@ class _BaseStreamManagerWrapper:
         self._stream_wrapper_func = stream_wrapper_func
         self._input_kwargs = input_kwargs
         self._parent_span_id_at_entry = None
+        self._trace_across_async_contexts = trace_across_async_contexts
 
     def _create_span(self):
         start_time = time.time()
         span_id = str(uuid.uuid4())
         current_depth = 0
-        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
-            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        if (
+            self._parent_span_id_at_entry
+            and self._parent_span_id_at_entry in self._trace_client._span_depths
+        ):
+            current_depth = (
+                self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+            )
         self._trace_client._span_depths[span_id] = current_depth
         span = TraceSpan(
             function=self._span_name,
@@ -2886,7 +3207,7 @@ class _BaseStreamManagerWrapper:
             message=self._span_name,
             created_at=start_time,
             span_type="llm",
-            parent_span_id=self._parent_span_id_at_entry
+            parent_span_id=self._parent_span_id_at_entry,
         )
         self._trace_client.add_span(span)
         return span_id, span
@@ -2898,7 +3219,10 @@ class _BaseStreamManagerWrapper:
         if span_id in self._trace_client._span_depths:
             del self._trace_client._span_depths[span_id]
 
-class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
+
+class _TracedAsyncStreamManagerWrapper(
+    _BaseStreamManagerWrapper, AbstractAsyncContextManager
+):
     async def __aenter__(self):
         self._parent_span_id_at_entry = self._trace_client.get_current_span()
         if not self._trace_client:
@@ -2911,17 +3235,22 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
         # Call the original __aenter__ and expect it to be an async generator
         raw_iterator = await self._original_manager.__aenter__()
         span.output = "<pending stream>"
-        return self._stream_wrapper_func(raw_iterator, self._client, span)
+        return self._stream_wrapper_func(
+            raw_iterator, self._client, span, self._trace_across_async_contexts
+        )
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self, '_span_context_token'):
+        if hasattr(self, "_span_context_token"):
             span_id = self._trace_client.get_current_span()
             self._finalize_span(span_id)
             self._trace_client.reset_current_span(self._span_context_token)
-            delattr(self, '_span_context_token')
+            delattr(self, "_span_context_token")
         return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
 
-class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
+
+class _TracedSyncStreamManagerWrapper(
+    _BaseStreamManagerWrapper, AbstractContextManager
+):
     def __enter__(self):
         self._parent_span_id_at_entry = self._trace_client.get_current_span()
         if not self._trace_client:
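The async manager above and the sync manager continued in the next hunk share the same delegation shape: forward `__enter__`/`__exit__` (or their async twins) to the wrapped Anthropic stream manager while opening a span on entry and finalizing it on exit. A generic sketch, where `open_span`/`close_span` are hypothetical stand-ins for the span bookkeeping shown in the diff:

```python
# Generic shape of the traced stream managers; open_span()/close_span() are
# hypothetical stand-ins, not names from this module.
from contextlib import AbstractContextManager

class DelegatingStreamManager(AbstractContextManager):
    def __init__(self, inner, open_span, close_span):
        self._inner = inner
        self._open_span = open_span
        self._close_span = close_span

    def __enter__(self):
        self._open_span()
        return self._inner.__enter__()  # hand back the (wrappable) iterator

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close_span()
        return self._inner.__exit__(exc_type, exc_val, exc_tb)
```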
@@ -2933,16 +3262,19 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
 
         raw_iterator = self._original_manager.__enter__()
         span.output = "<pending stream>"
-        return self._stream_wrapper_func(raw_iterator, self._client, span)
+        return self._stream_wrapper_func(
+            raw_iterator, self._client, span, self._trace_across_async_contexts
+        )
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self, '_span_context_token'):
+        if hasattr(self, "_span_context_token"):
             span_id = self._trace_client.get_current_span()
             self._finalize_span(span_id)
             self._trace_client.reset_current_span(self._span_context_token)
-            delattr(self, '_span_context_token')
+            delattr(self, "_span_context_token")
         return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
 
+
 # --- Helper function for instance-prefixed qual_name ---
 def get_instance_prefixed_name(instance, class_name, class_identifiers):
     """
@@ -2951,11 +3283,13 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
     """
     if class_name in class_identifiers:
         class_config = class_identifiers[class_name]
-        attr = class_config['identifier']
-
+        attr = class_config["identifier"]
+
         if hasattr(instance, attr):
             instance_name = getattr(instance, attr)
             return instance_name
         else:
-            raise Exception(
-
+            raise Exception(
+                f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
+            )
+    return None