judgeval 0.0.41__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/s3_storage.py +3 -1
- judgeval/common/tracer.py +921 -103
- judgeval/common/utils.py +1 -1
- judgeval/constants.py +5 -0
- judgeval/data/trace.py +2 -1
- judgeval/integrations/langgraph.py +218 -34
- judgeval/rules.py +60 -50
- judgeval/run_evaluation.py +36 -26
- judgeval/utils/alerts.py +8 -0
- {judgeval-0.0.41.dist-info → judgeval-0.0.42.dist-info}/METADATA +35 -46
- {judgeval-0.0.41.dist-info → judgeval-0.0.42.dist-info}/RECORD +13 -13
- {judgeval-0.0.41.dist-info → judgeval-0.0.42.dist-info}/WHEEL +0 -0
- {judgeval-0.0.41.dist-info → judgeval-0.0.42.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -18,7 +18,7 @@ import sys
|
|
18
18
|
import json
|
19
19
|
from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
|
20
20
|
from dataclasses import dataclass, field
|
21
|
-
from datetime import datetime
|
21
|
+
from datetime import datetime, timezone
|
22
22
|
from http import HTTPStatus
|
23
23
|
from typing import (
|
24
24
|
Any,
|
@@ -49,12 +49,17 @@ from google import genai
|
|
49
49
|
from judgeval.constants import (
|
50
50
|
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
51
51
|
JUDGMENT_TRACES_SAVE_API_URL,
|
52
|
+
JUDGMENT_TRACES_UPSERT_API_URL,
|
53
|
+
JUDGMENT_TRACES_USAGE_CHECK_API_URL,
|
54
|
+
JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
|
52
55
|
JUDGMENT_TRACES_FETCH_API_URL,
|
53
56
|
RABBITMQ_HOST,
|
54
57
|
RABBITMQ_PORT,
|
55
58
|
RABBITMQ_QUEUE,
|
56
59
|
JUDGMENT_TRACES_DELETE_API_URL,
|
57
60
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
61
|
+
JUDGMENT_TRACES_SPANS_BATCH_API_URL,
|
62
|
+
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
|
58
63
|
)
|
59
64
|
from judgeval.data import Example, Trace, TraceSpan, TraceUsage
|
60
65
|
from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
|
@@ -66,6 +71,8 @@ from judgeval.common.exceptions import JudgmentAPIError
|
|
66
71
|
# Standard library imports needed for the new class
|
67
72
|
import concurrent.futures
|
68
73
|
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
74
|
+
import queue
|
75
|
+
import atexit
|
69
76
|
|
70
77
|
# Define context variables for tracking the current trace and the current span within a trace
|
71
78
|
current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
|
@@ -142,13 +149,18 @@ class TraceManagerClient:
|
|
142
149
|
|
143
150
|
return response.json()
|
144
151
|
|
145
|
-
def save_trace(self, trace_data: dict, offline_mode: bool = False):
|
152
|
+
def save_trace(self, trace_data: dict, offline_mode: bool = False, final_save: bool = True):
|
146
153
|
"""
|
147
154
|
Saves a trace to the Judgment Supabase and optionally to S3 if configured.
|
148
155
|
|
149
156
|
Args:
|
150
157
|
trace_data: The trace data to save
|
158
|
+
offline_mode: Whether running in offline mode
|
159
|
+
final_save: Whether this is the final save (controls S3 saving)
|
151
160
|
NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
|
161
|
+
|
162
|
+
Returns:
|
163
|
+
dict: Server response containing UI URL and other metadata
|
152
164
|
"""
|
153
165
|
# Save to Judgment API
|
154
166
|
|
@@ -170,7 +182,6 @@ class TraceManagerClient:
|
|
170
182
|
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
171
183
|
|
172
184
|
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
173
|
-
|
174
185
|
response = requests.post(
|
175
186
|
JUDGMENT_TRACES_SAVE_API_URL,
|
176
187
|
data=serialized_trace_data,
|
@@ -187,8 +198,11 @@ class TraceManagerClient:
|
|
187
198
|
elif response.status_code != HTTPStatus.OK:
|
188
199
|
raise ValueError(f"Failed to save trace data: {response.text}")
|
189
200
|
|
190
|
-
#
|
191
|
-
|
201
|
+
# Parse server response
|
202
|
+
server_response = response.json()
|
203
|
+
|
204
|
+
# If S3 storage is enabled, save to S3 only on final save
|
205
|
+
if self.tracer and self.tracer.use_s3 and final_save:
|
192
206
|
try:
|
193
207
|
s3_key = self.tracer.s3_storage.save_trace(
|
194
208
|
trace_data=trace_data,
|
@@ -199,9 +213,136 @@ class TraceManagerClient:
|
|
199
213
|
except Exception as e:
|
200
214
|
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
201
215
|
|
202
|
-
if not offline_mode and "ui_results_url" in
|
203
|
-
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={
|
216
|
+
if not offline_mode and "ui_results_url" in server_response:
|
217
|
+
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
204
218
|
rprint(pretty_str)
|
219
|
+
|
220
|
+
return server_response
|
221
|
+
|
222
|
+
def check_usage_limits(self, count: int = 1):
|
223
|
+
"""
|
224
|
+
Check if the organization can use the requested number of traces without exceeding limits.
|
225
|
+
|
226
|
+
Args:
|
227
|
+
count: Number of traces to check for (default: 1)
|
228
|
+
|
229
|
+
Returns:
|
230
|
+
dict: Server response with rate limit status and usage info
|
231
|
+
|
232
|
+
Raises:
|
233
|
+
ValueError: If rate limits would be exceeded or other errors occur
|
234
|
+
"""
|
235
|
+
response = requests.post(
|
236
|
+
JUDGMENT_TRACES_USAGE_CHECK_API_URL,
|
237
|
+
json={"count": count},
|
238
|
+
headers={
|
239
|
+
"Content-Type": "application/json",
|
240
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
241
|
+
"X-Organization-Id": self.organization_id
|
242
|
+
},
|
243
|
+
verify=True
|
244
|
+
)
|
245
|
+
|
246
|
+
if response.status_code == HTTPStatus.FORBIDDEN:
|
247
|
+
# Rate limits exceeded
|
248
|
+
error_data = response.json()
|
249
|
+
raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
|
250
|
+
elif response.status_code != HTTPStatus.OK:
|
251
|
+
raise ValueError(f"Failed to check usage limits: {response.text}")
|
252
|
+
|
253
|
+
return response.json()
|
254
|
+
|
255
|
+
def upsert_trace(self, trace_data: dict, offline_mode: bool = False, show_link: bool = True, final_save: bool = True):
|
256
|
+
"""
|
257
|
+
Upserts a trace to the Judgment API (always overwrites if exists).
|
258
|
+
|
259
|
+
Args:
|
260
|
+
trace_data: The trace data to upsert
|
261
|
+
offline_mode: Whether running in offline mode
|
262
|
+
show_link: Whether to show the UI link (for live tracing)
|
263
|
+
final_save: Whether this is the final save (controls S3 saving)
|
264
|
+
|
265
|
+
Returns:
|
266
|
+
dict: Server response containing UI URL and other metadata
|
267
|
+
"""
|
268
|
+
def fallback_encoder(obj):
|
269
|
+
"""
|
270
|
+
Custom JSON encoder fallback.
|
271
|
+
Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
|
272
|
+
"""
|
273
|
+
try:
|
274
|
+
return repr(obj)
|
275
|
+
except Exception:
|
276
|
+
try:
|
277
|
+
return str(obj)
|
278
|
+
except Exception as e:
|
279
|
+
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
280
|
+
|
281
|
+
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
282
|
+
|
283
|
+
response = requests.post(
|
284
|
+
JUDGMENT_TRACES_UPSERT_API_URL,
|
285
|
+
data=serialized_trace_data,
|
286
|
+
headers={
|
287
|
+
"Content-Type": "application/json",
|
288
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
289
|
+
"X-Organization-Id": self.organization_id
|
290
|
+
},
|
291
|
+
verify=True
|
292
|
+
)
|
293
|
+
|
294
|
+
if response.status_code != HTTPStatus.OK:
|
295
|
+
raise ValueError(f"Failed to upsert trace data: {response.text}")
|
296
|
+
|
297
|
+
# Parse server response
|
298
|
+
server_response = response.json()
|
299
|
+
|
300
|
+
# If S3 storage is enabled, save to S3 only on final save
|
301
|
+
if self.tracer and self.tracer.use_s3 and final_save:
|
302
|
+
try:
|
303
|
+
s3_key = self.tracer.s3_storage.save_trace(
|
304
|
+
trace_data=trace_data,
|
305
|
+
trace_id=trace_data["trace_id"],
|
306
|
+
project_name=trace_data["project_name"]
|
307
|
+
)
|
308
|
+
print(f"Trace also saved to S3 at key: {s3_key}")
|
309
|
+
except Exception as e:
|
310
|
+
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
311
|
+
|
312
|
+
if not offline_mode and show_link and "ui_results_url" in server_response:
|
313
|
+
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
314
|
+
rprint(pretty_str)
|
315
|
+
|
316
|
+
return server_response
|
317
|
+
|
318
|
+
def update_usage_counters(self, count: int = 1):
|
319
|
+
"""
|
320
|
+
Update trace usage counters after successfully saving traces.
|
321
|
+
|
322
|
+
Args:
|
323
|
+
count: Number of traces to count (default: 1)
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
dict: Server response with updated usage information
|
327
|
+
|
328
|
+
Raises:
|
329
|
+
ValueError: If the update fails
|
330
|
+
"""
|
331
|
+
response = requests.post(
|
332
|
+
JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
|
333
|
+
json={"count": count},
|
334
|
+
headers={
|
335
|
+
"Content-Type": "application/json",
|
336
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
337
|
+
"X-Organization-Id": self.organization_id
|
338
|
+
},
|
339
|
+
verify=True
|
340
|
+
)
|
341
|
+
|
342
|
+
if response.status_code != HTTPStatus.OK:
|
343
|
+
raise ValueError(f"Failed to update usage counters: {response.text}")
|
344
|
+
|
345
|
+
return response.json()
|
205
346
|
|
206
347
|
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
207
348
|
def save_annotation(self, annotation: TraceAnnotation):
|
@@ -324,35 +465,48 @@ class TraceClient:
|
|
324
465
|
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
325
466
|
self.evaluation_runs: List[EvaluationRun] = []
|
326
467
|
self.annotations: List[TraceAnnotation] = []
|
327
|
-
self.start_time =
|
468
|
+
self.start_time = None # Will be set after first successful save
|
328
469
|
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
|
329
470
|
self.visited_nodes = []
|
330
471
|
self.executed_tools = []
|
331
472
|
self.executed_node_tools = []
|
332
473
|
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
474
|
+
|
475
|
+
# Get background span service from tracer
|
476
|
+
self.background_span_service = tracer.get_background_span_service() if tracer else None
|
333
477
|
|
334
478
|
def get_current_span(self):
|
335
479
|
"""Get the current span from the context var"""
|
336
|
-
return
|
480
|
+
return self.tracer.get_current_span()
|
337
481
|
|
338
482
|
def set_current_span(self, span: Any):
|
339
483
|
"""Set the current span from the context var"""
|
340
|
-
return
|
484
|
+
return self.tracer.set_current_span(span)
|
341
485
|
|
342
486
|
def reset_current_span(self, token: Any):
|
343
487
|
"""Reset the current span from the context var"""
|
344
|
-
|
488
|
+
self.tracer.reset_current_span(token)
|
345
489
|
|
346
490
|
@contextmanager
|
347
491
|
def span(self, name: str, span_type: SpanType = "span"):
|
348
492
|
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
493
|
+
is_first_span = len(self.trace_spans) == 0
|
494
|
+
if is_first_span:
|
495
|
+
try:
|
496
|
+
trace_id, server_response = self.save_with_rate_limiting(overwrite=self.overwrite, final_save=False)
|
497
|
+
# Set start_time after first successful save
|
498
|
+
if self.start_time is None:
|
499
|
+
self.start_time = time.time()
|
500
|
+
# Link will be shown by upsert_trace method
|
501
|
+
except Exception as e:
|
502
|
+
warnings.warn(f"Failed to save initial trace for live tracking: {e}")
|
349
503
|
start_time = time.time()
|
350
504
|
|
351
505
|
# Generate a unique ID for *this specific span invocation*
|
352
506
|
span_id = str(uuid.uuid4())
|
353
507
|
|
354
|
-
parent_span_id =
|
355
|
-
token =
|
508
|
+
parent_span_id = self.get_current_span() # Get ID of the parent span from context var
|
509
|
+
token = self.set_current_span(span_id) # Set *this* span's ID as the current one
|
356
510
|
|
357
511
|
current_depth = 0
|
358
512
|
if parent_span_id and parent_span_id in self._span_depths:
|
@@ -372,16 +526,27 @@ class TraceClient:
|
|
372
526
|
)
|
373
527
|
self.add_span(span)
|
374
528
|
|
529
|
+
|
530
|
+
|
531
|
+
# Queue span with initial state (input phase)
|
532
|
+
if self.background_span_service:
|
533
|
+
self.background_span_service.queue_span(span, span_state="input")
|
534
|
+
|
375
535
|
try:
|
376
536
|
yield self
|
377
537
|
finally:
|
378
538
|
duration = time.time() - start_time
|
379
539
|
span.duration = duration
|
540
|
+
|
541
|
+
# Queue span with completed state (output phase)
|
542
|
+
if self.background_span_service:
|
543
|
+
self.background_span_service.queue_span(span, span_state="completed")
|
544
|
+
|
380
545
|
# Clean up depth tracking for this span_id
|
381
546
|
if span_id in self._span_depths:
|
382
547
|
del self._span_depths[span_id]
|
383
548
|
# Reset context var
|
384
|
-
|
549
|
+
self.reset_current_span(token)
|
385
550
|
|
386
551
|
def async_evaluate(
|
387
552
|
self,
|
@@ -445,8 +610,7 @@ class TraceClient:
|
|
445
610
|
# span_id_at_eval_call = current_span_var.get()
|
446
611
|
# print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
|
447
612
|
# Prioritize explicitly passed span_id, fallback to context var
|
448
|
-
|
449
|
-
span_id_to_use = span_id if span_id is not None else current_span_ctx_var if current_span_ctx_var is not None else self.tracer.get_current_span()
|
613
|
+
span_id_to_use = span_id if span_id is not None else self.get_current_span()
|
450
614
|
# print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
|
451
615
|
# --- End Modification ---
|
452
616
|
|
@@ -469,6 +633,17 @@ class TraceClient:
|
|
469
633
|
)
|
470
634
|
|
471
635
|
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
636
|
+
|
637
|
+
# Queue evaluation run through background service
|
638
|
+
if self.background_span_service and span_id_to_use:
|
639
|
+
# Get the current span data to avoid race conditions
|
640
|
+
current_span = self.span_id_to_span.get(span_id_to_use)
|
641
|
+
if current_span:
|
642
|
+
self.background_span_service.queue_evaluation_run(
|
643
|
+
eval_run,
|
644
|
+
span_id=span_id_to_use,
|
645
|
+
span_data=current_span
|
646
|
+
)
|
472
647
|
|
473
648
|
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
474
649
|
# --- Modification: Use span_id from eval_run ---
|
@@ -488,19 +663,27 @@ class TraceClient:
|
|
488
663
|
return self
|
489
664
|
|
490
665
|
def record_input(self, inputs: dict):
|
491
|
-
current_span_id =
|
666
|
+
current_span_id = self.get_current_span()
|
492
667
|
if current_span_id:
|
493
668
|
span = self.span_id_to_span[current_span_id]
|
494
669
|
# Ignore self parameter
|
495
670
|
if "self" in inputs:
|
496
671
|
del inputs["self"]
|
497
672
|
span.inputs = inputs
|
673
|
+
|
674
|
+
# Queue span with input data
|
675
|
+
if self.background_span_service:
|
676
|
+
self.background_span_service.queue_span(span, span_state="input")
|
498
677
|
|
499
678
|
def record_agent_name(self, agent_name: str):
|
500
|
-
current_span_id =
|
679
|
+
current_span_id = self.get_current_span()
|
501
680
|
if current_span_id:
|
502
681
|
span = self.span_id_to_span[current_span_id]
|
503
682
|
span.agent_name = agent_name
|
683
|
+
|
684
|
+
# Queue span with agent_name data
|
685
|
+
if self.background_span_service:
|
686
|
+
self.background_span_service.queue_span(span, span_state="agent_name")
|
504
687
|
|
505
688
|
def record_state_before(self, state: dict):
|
506
689
|
"""Records the agent's state before a tool execution on the current span.
|
@@ -508,60 +691,91 @@ class TraceClient:
|
|
508
691
|
Args:
|
509
692
|
state: A dictionary representing the agent's state.
|
510
693
|
"""
|
511
|
-
current_span_id =
|
694
|
+
current_span_id = self.get_current_span()
|
512
695
|
if current_span_id:
|
513
696
|
span = self.span_id_to_span[current_span_id]
|
514
697
|
span.state_before = state
|
515
698
|
|
699
|
+
# Queue span with state_before data
|
700
|
+
if self.background_span_service:
|
701
|
+
self.background_span_service.queue_span(span, span_state="state_before")
|
702
|
+
|
516
703
|
def record_state_after(self, state: dict):
|
517
704
|
"""Records the agent's state after a tool execution on the current span.
|
518
705
|
|
519
706
|
Args:
|
520
707
|
state: A dictionary representing the agent's state.
|
521
708
|
"""
|
522
|
-
current_span_id =
|
709
|
+
current_span_id = self.get_current_span()
|
523
710
|
if current_span_id:
|
524
711
|
span = self.span_id_to_span[current_span_id]
|
525
712
|
span.state_after = state
|
713
|
+
|
714
|
+
# Queue span with state_after data
|
715
|
+
if self.background_span_service:
|
716
|
+
self.background_span_service.queue_span(span, span_state="state_after")
|
526
717
|
|
527
718
|
async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
|
528
719
|
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
529
720
|
try:
|
530
721
|
result = await coroutine
|
531
722
|
setattr(span, field, result)
|
723
|
+
|
724
|
+
# Queue span with output data now that coroutine is complete
|
725
|
+
if self.background_span_service and field == "output":
|
726
|
+
self.background_span_service.queue_span(span, span_state="output")
|
727
|
+
|
532
728
|
return result
|
533
729
|
except Exception as e:
|
534
730
|
setattr(span, field, f"Error: {str(e)}")
|
731
|
+
|
732
|
+
# Queue span even if there was an error
|
733
|
+
if self.background_span_service and field == "output":
|
734
|
+
self.background_span_service.queue_span(span, span_state="output")
|
735
|
+
|
535
736
|
raise
|
536
737
|
|
537
738
|
def record_output(self, output: Any):
|
538
|
-
current_span_id =
|
739
|
+
current_span_id = self.get_current_span()
|
539
740
|
if current_span_id:
|
540
741
|
span = self.span_id_to_span[current_span_id]
|
541
742
|
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
542
743
|
|
543
744
|
if inspect.iscoroutine(output):
|
544
745
|
asyncio.create_task(self._update_coroutine(span, output, "output"))
|
746
|
+
|
747
|
+
# # Queue span with output data (unless it's pending)
|
748
|
+
if self.background_span_service and not inspect.iscoroutine(output):
|
749
|
+
self.background_span_service.queue_span(span, span_state="output")
|
545
750
|
|
546
751
|
return span # Return the created entry
|
547
752
|
# Removed else block - original didn't have one
|
548
753
|
return None # Return None if no span_id found
|
549
754
|
|
550
755
|
def record_usage(self, usage: TraceUsage):
|
551
|
-
current_span_id =
|
756
|
+
current_span_id = self.get_current_span()
|
552
757
|
if current_span_id:
|
553
758
|
span = self.span_id_to_span[current_span_id]
|
554
759
|
span.usage = usage
|
555
760
|
|
761
|
+
# Queue span with usage data
|
762
|
+
if self.background_span_service:
|
763
|
+
self.background_span_service.queue_span(span, span_state="usage")
|
764
|
+
|
556
765
|
return span # Return the created entry
|
557
766
|
# Removed else block - original didn't have one
|
558
767
|
return None # Return None if no span_id found
|
559
768
|
|
560
769
|
def record_error(self, error: Dict[str, Any]):
|
561
|
-
current_span_id =
|
770
|
+
current_span_id = self.get_current_span()
|
562
771
|
if current_span_id:
|
563
772
|
span = self.span_id_to_span[current_span_id]
|
564
773
|
span.error = error
|
774
|
+
|
775
|
+
# Queue span with error data
|
776
|
+
if self.background_span_service:
|
777
|
+
self.background_span_service.queue_span(span, span_state="error")
|
778
|
+
|
565
779
|
return span
|
566
780
|
return None
|
567
781
|
|
@@ -580,13 +794,19 @@ class TraceClient:
|
|
580
794
|
"""
|
581
795
|
Get the total duration of this trace
|
582
796
|
"""
|
797
|
+
if self.start_time is None:
|
798
|
+
return 0.0 # No duration if trace hasn't been saved yet
|
583
799
|
return time.time() - self.start_time
|
584
800
|
|
585
801
|
def save(self, overwrite: bool = False) -> Tuple[str, dict]:
|
586
802
|
"""
|
587
803
|
Save the current trace to the database.
|
588
|
-
Returns a tuple of (trace_id,
|
804
|
+
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
589
805
|
"""
|
806
|
+
# Set start_time if this is the first save
|
807
|
+
if self.start_time is None:
|
808
|
+
self.start_time = time.time()
|
809
|
+
|
590
810
|
# Calculate total elapsed time
|
591
811
|
total_duration = self.get_duration()
|
592
812
|
# Create trace document - Always use standard keys for top-level counts
|
@@ -594,7 +814,7 @@ class TraceClient:
|
|
594
814
|
"trace_id": self.trace_id,
|
595
815
|
"name": self.name,
|
596
816
|
"project_name": self.project_name,
|
597
|
-
"created_at": datetime.
|
817
|
+
"created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
|
598
818
|
"duration": total_duration,
|
599
819
|
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
600
820
|
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
@@ -604,14 +824,79 @@ class TraceClient:
|
|
604
824
|
"parent_name": self.parent_name
|
605
825
|
}
|
606
826
|
# --- Log trace data before saving ---
|
607
|
-
self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode)
|
827
|
+
server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
|
608
828
|
|
609
829
|
# upload annotations
|
610
830
|
# TODO: batch to the log endpoint
|
611
831
|
for annotation in self.annotations:
|
612
832
|
self.trace_manager_client.save_annotation(annotation)
|
613
833
|
|
614
|
-
return self.trace_id,
|
834
|
+
return self.trace_id, server_response
|
835
|
+
|
836
|
+
def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
|
837
|
+
"""
|
838
|
+
Save the current trace to the database with rate limiting checks.
|
839
|
+
First checks usage limits, then upserts the trace if allowed.
|
840
|
+
|
841
|
+
Args:
|
842
|
+
overwrite: Whether to overwrite existing traces
|
843
|
+
final_save: Whether this is the final save (updates usage counters)
|
844
|
+
|
845
|
+
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
846
|
+
"""
|
847
|
+
|
848
|
+
|
849
|
+
# Calculate total elapsed time
|
850
|
+
total_duration = self.get_duration()
|
851
|
+
|
852
|
+
# Create trace document
|
853
|
+
trace_data = {
|
854
|
+
"trace_id": self.trace_id,
|
855
|
+
"name": self.name,
|
856
|
+
"project_name": self.project_name,
|
857
|
+
"created_at": datetime.utcfromtimestamp(time.time()).isoformat(),
|
858
|
+
"duration": total_duration,
|
859
|
+
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
860
|
+
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
861
|
+
"overwrite": overwrite,
|
862
|
+
"offline_mode": self.tracer.offline_mode,
|
863
|
+
"parent_trace_id": self.parent_trace_id,
|
864
|
+
"parent_name": self.parent_name
|
865
|
+
}
|
866
|
+
|
867
|
+
# Check usage limits first
|
868
|
+
try:
|
869
|
+
usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
|
870
|
+
# Usage check passed silently - no need to show detailed info
|
871
|
+
except ValueError as e:
|
872
|
+
# Rate limit exceeded
|
873
|
+
warnings.warn(f"Rate limit check failed for live tracing: {e}")
|
874
|
+
raise e
|
875
|
+
|
876
|
+
# If usage check passes, upsert the trace
|
877
|
+
server_response = self.trace_manager_client.upsert_trace(
|
878
|
+
trace_data,
|
879
|
+
offline_mode=self.tracer.offline_mode,
|
880
|
+
show_link=not final_save, # Show link only on initial save, not final save
|
881
|
+
final_save=final_save # Pass final_save to control S3 saving
|
882
|
+
)
|
883
|
+
|
884
|
+
# Update usage counters only on final save
|
885
|
+
if final_save:
|
886
|
+
try:
|
887
|
+
usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
|
888
|
+
# Usage updated silently - no need to show detailed usage info
|
889
|
+
except ValueError as e:
|
890
|
+
# Log warning but don't fail the trace save since the trace was already saved
|
891
|
+
warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
|
892
|
+
|
893
|
+
# Upload annotations
|
894
|
+
# TODO: batch to the log endpoint
|
895
|
+
for annotation in self.annotations:
|
896
|
+
self.trace_manager_client.save_annotation(annotation)
|
897
|
+
if self.start_time is None:
|
898
|
+
self.start_time = time.time()
|
899
|
+
return self.trace_id, server_response
|
615
900
|
|
616
901
|
def delete(self):
|
617
902
|
return self.trace_manager_client.delete_trace(self.trace_id)
|
@@ -648,6 +933,338 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
|
|
648
933
|
pass
|
649
934
|
|
650
935
|
current_trace.record_error(formatted_exception)
|
936
|
+
|
937
|
+
# Queue the span with error state through background service
|
938
|
+
if current_trace.background_span_service:
|
939
|
+
current_span_id = current_trace.get_current_span()
|
940
|
+
if current_span_id and current_span_id in current_trace.span_id_to_span:
|
941
|
+
error_span = current_trace.span_id_to_span[current_span_id]
|
942
|
+
current_trace.background_span_service.queue_span(error_span, span_state="error")
|
943
|
+
|
944
|
+
class BackgroundSpanService:
|
945
|
+
"""
|
946
|
+
Background service for queueing and batching trace spans for efficient saving.
|
947
|
+
|
948
|
+
This service:
|
949
|
+
- Queues spans as they complete
|
950
|
+
- Batches them for efficient network usage
|
951
|
+
- Sends spans periodically or when batches reach a certain size
|
952
|
+
- Handles automatic flushing when the main event terminates
|
953
|
+
"""
|
954
|
+
|
955
|
+
def __init__(self, judgment_api_key: str, organization_id: str,
|
956
|
+
batch_size: int = 10, flush_interval: float = 5.0, num_workers: int = 1):
|
957
|
+
"""
|
958
|
+
Initialize the background span service.
|
959
|
+
|
960
|
+
Args:
|
961
|
+
judgment_api_key: API key for Judgment service
|
962
|
+
organization_id: Organization ID
|
963
|
+
batch_size: Number of spans to batch before sending (default: 10)
|
964
|
+
flush_interval: Time in seconds between automatic flushes (default: 5.0)
|
965
|
+
num_workers: Number of worker threads to process the queue (default: 1)
|
966
|
+
"""
|
967
|
+
self.judgment_api_key = judgment_api_key
|
968
|
+
self.organization_id = organization_id
|
969
|
+
self.batch_size = batch_size
|
970
|
+
self.flush_interval = flush_interval
|
971
|
+
self.num_workers = max(1, num_workers) # Ensure at least 1 worker
|
972
|
+
|
973
|
+
# Queue for pending spans
|
974
|
+
self._span_queue = queue.Queue()
|
975
|
+
|
976
|
+
# Background threads for processing spans
|
977
|
+
self._worker_threads = []
|
978
|
+
self._shutdown_event = threading.Event()
|
979
|
+
|
980
|
+
# Track spans that have been sent
|
981
|
+
self._sent_spans = set()
|
982
|
+
|
983
|
+
# Register cleanup on exit
|
984
|
+
atexit.register(self.shutdown)
|
985
|
+
|
986
|
+
# Start the background workers
|
987
|
+
self._start_workers()
|
988
|
+
|
989
|
+
def _start_workers(self):
|
990
|
+
"""Start the background worker threads."""
|
991
|
+
for i in range(self.num_workers):
|
992
|
+
if len(self._worker_threads) < self.num_workers:
|
993
|
+
worker_thread = threading.Thread(
|
994
|
+
target=self._worker_loop,
|
995
|
+
daemon=True,
|
996
|
+
name=f"SpanWorker-{i+1}"
|
997
|
+
)
|
998
|
+
worker_thread.start()
|
999
|
+
self._worker_threads.append(worker_thread)
|
1000
|
+
|
1001
|
+
def _worker_loop(self):
|
1002
|
+
"""Main worker loop that processes spans in batches."""
|
1003
|
+
batch = []
|
1004
|
+
last_flush_time = time.time()
|
1005
|
+
pending_task_count = 0 # Track how many tasks we've taken from queue but not marked done
|
1006
|
+
|
1007
|
+
while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
|
1008
|
+
try:
|
1009
|
+
# First, do a blocking get to wait for at least one item
|
1010
|
+
if not batch: # Only block if we don't have items already
|
1011
|
+
try:
|
1012
|
+
span_data = self._span_queue.get(timeout=1.0)
|
1013
|
+
batch.append(span_data)
|
1014
|
+
pending_task_count += 1
|
1015
|
+
except queue.Empty:
|
1016
|
+
# No new spans, continue to check for flush conditions
|
1017
|
+
pass
|
1018
|
+
|
1019
|
+
# Then, do non-blocking gets to drain any additional available items
|
1020
|
+
# up to our batch size limit
|
1021
|
+
while len(batch) < self.batch_size:
|
1022
|
+
try:
|
1023
|
+
span_data = self._span_queue.get_nowait() # Non-blocking
|
1024
|
+
batch.append(span_data)
|
1025
|
+
pending_task_count += 1
|
1026
|
+
except queue.Empty:
|
1027
|
+
# No more items immediately available
|
1028
|
+
break
|
1029
|
+
|
1030
|
+
current_time = time.time()
|
1031
|
+
should_flush = (
|
1032
|
+
len(batch) >= self.batch_size or
|
1033
|
+
(batch and (current_time - last_flush_time) >= self.flush_interval)
|
1034
|
+
)
|
1035
|
+
|
1036
|
+
if should_flush and batch:
|
1037
|
+
self._send_batch(batch)
|
1038
|
+
|
1039
|
+
# Only mark tasks as done after successful sending
|
1040
|
+
for _ in range(pending_task_count):
|
1041
|
+
self._span_queue.task_done()
|
1042
|
+
pending_task_count = 0 # Reset counter
|
1043
|
+
|
1044
|
+
batch.clear()
|
1045
|
+
last_flush_time = current_time
|
1046
|
+
|
1047
|
+
except Exception as e:
|
1048
|
+
warnings.warn(f"Error in span service worker loop: {e}")
|
1049
|
+
# On error, still need to mark tasks as done to prevent hanging
|
1050
|
+
for _ in range(pending_task_count):
|
1051
|
+
self._span_queue.task_done()
|
1052
|
+
pending_task_count = 0
|
1053
|
+
batch.clear()
|
1054
|
+
|
1055
|
+
# Final flush on shutdown
|
1056
|
+
if batch:
|
1057
|
+
self._send_batch(batch)
|
1058
|
+
# Mark remaining tasks as done
|
1059
|
+
for _ in range(pending_task_count):
|
1060
|
+
self._span_queue.task_done()
|
1061
|
+
|
1062
|
+
def _send_batch(self, batch: List[Dict[str, Any]]):
|
1063
|
+
"""
|
1064
|
+
Send a batch of spans to the server.
|
1065
|
+
|
1066
|
+
Args:
|
1067
|
+
batch: List of span dictionaries to send
|
1068
|
+
"""
|
1069
|
+
if not batch:
|
1070
|
+
return
|
1071
|
+
|
1072
|
+
try:
|
1073
|
+
# Group spans by type for different endpoints
|
1074
|
+
spans_to_send = []
|
1075
|
+
evaluation_runs_to_send = []
|
1076
|
+
|
1077
|
+
for item in batch:
|
1078
|
+
if item['type'] == 'span':
|
1079
|
+
spans_to_send.append(item['data'])
|
1080
|
+
elif item['type'] == 'evaluation_run':
|
1081
|
+
evaluation_runs_to_send.append(item['data'])
|
1082
|
+
|
1083
|
+
# Send spans if any
|
1084
|
+
if spans_to_send:
|
1085
|
+
self._send_spans_batch(spans_to_send)
|
1086
|
+
|
1087
|
+
# Send evaluation runs if any
|
1088
|
+
if evaluation_runs_to_send:
|
1089
|
+
self._send_evaluation_runs_batch(evaluation_runs_to_send)
|
1090
|
+
|
1091
|
+
except Exception as e:
|
1092
|
+
warnings.warn(f"Failed to send span batch: {e}")
|
1093
|
+
|
1094
|
+
def _send_spans_batch(self, spans: List[Dict[str, Any]]):
    """
    POST a batch of span payloads to the spans batch endpoint.

    The payload is serialized to JSON; objects the JSON encoder cannot
    handle are replaced by their ``repr`` (or ``str``) text. A non-200
    response, a network error, or a serialization error is reported
    through ``warnings.warn`` and not raised.

    Args:
        spans: List of span dictionaries to send
    """
    request_body = {
        "spans": spans,
        "organization_id": self.organization_id
    }

    def fallback_encoder(obj):
        # Try repr() first, then str(); describe the object if both fail
        failure = None
        for render in (repr, str):
            try:
                return render(obj)
            except Exception as exc:
                failure = exc
        return f"<Unserializable object of type {type(obj).__name__}: {failure}>"

    try:
        serialized_data = json.dumps(request_body, default=fallback_encoder)

        auth_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.judgment_api_key}",
            "X-Organization-Id": self.organization_id
        }

        # The timeout keeps the caller from hanging on an unresponsive server
        response = requests.post(
            JUDGMENT_TRACES_SPANS_BATCH_API_URL,
            data=serialized_data,
            headers=auth_headers,
            verify=True,
            timeout=30
        )

        if response.status_code != HTTPStatus.OK:
            warnings.warn(f"Failed to send spans batch: HTTP {response.status_code} - {response.text}")

    except requests.RequestException as e:
        warnings.warn(f"Network error sending spans batch: {e}")
    except Exception as e:
        warnings.warn(f"Failed to serialize or send spans batch: {e}")
|
1135
|
+
|
1136
|
+
def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
    """
    POST a batch of evaluation runs, each paired with its span data.

    Every queued dictionary is split into an 'evaluation_run' part (all keys
    except the span/queue bookkeeping keys), an 'associated_span' part and
    the 'queued_at' value. A non-200 response, a network error, or any other
    failure while serializing or sending is reported through
    ``warnings.warn`` and not raised.

    Args:
        evaluation_runs: List of evaluation-run dictionaries as queued
            (already unwrapped from the queue item's 'data' key)
    """
    # Keys that describe the span / queueing rather than the evaluation run
    bookkeeping_keys = ('associated_span_id', 'span_data', 'queued_at')

    def build_entry(run_data):
        run_fields = {}
        for field_name, field_value in run_data.items():
            if field_name not in bookkeeping_keys:
                run_fields[field_name] = field_value
        return {
            "evaluation_run": run_fields,
            "associated_span": {
                "span_id": run_data.get('associated_span_id'),
                "span_data": run_data.get('span_data')
            },
            "queued_at": run_data.get('queued_at')
        }

    evaluation_entries = [build_entry(run_data) for run_data in evaluation_runs]

    request_body = {
        "organization_id": self.organization_id,
        "evaluation_entries": evaluation_entries  # Each entry contains both eval run + span data
    }

    def fallback_encoder(obj):
        # Try repr() first, then str(); describe the object if both fail
        failure = None
        for render in (repr, str):
            try:
                return render(obj)
            except Exception as exc:
                failure = exc
        return f"<Unserializable object of type {type(obj).__name__}: {failure}>"

    try:
        serialized_data = json.dumps(request_body, default=fallback_encoder)

        auth_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.judgment_api_key}",
            "X-Organization-Id": self.organization_id
        }

        # The timeout keeps the caller from hanging on an unresponsive server
        response = requests.post(
            JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
            data=serialized_data,
            headers=auth_headers,
            verify=True,
            timeout=30
        )

        if response.status_code != HTTPStatus.OK:
            warnings.warn(f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}")

    except requests.RequestException as e:
        warnings.warn(f"Network error sending evaluation runs batch: {e}")
    except Exception as e:
        warnings.warn(f"Failed to send evaluation runs batch: {e}")
|
1195
|
+
|
1196
|
+
def queue_span(self, span: TraceSpan, span_state: str = "input"):
    """
    Put a span on the queue for background sending.

    The span is dumped to a dict and tagged with its state and the time it
    was queued. Nothing is queued once the shutdown event has been set.

    Args:
        span: The TraceSpan object to queue
        span_state: State of the span ("input", "output", "completed")
    """
    if self._shutdown_event.is_set():
        return

    # Copy the dump so the extra bookkeeping keys never touch the source dict
    payload = dict(span.model_dump())
    payload.update(span_state=span_state, queued_at=time.time())

    self._span_queue.put({"type": "span", "data": payload})
|
1214
|
+
|
1215
|
+
def queue_evaluation_run(self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan):
    """
    Put an evaluation run on the queue for background sending.

    The evaluation run is dumped to a dict and extended with the id of the
    associated span, a dump of that span taken now, and the time it was
    queued. Nothing is queued once the shutdown event has been set.

    Args:
        evaluation_run: The EvaluationRun object to queue
        span_id: The span ID associated with this evaluation run
        span_data: The span data at the time of evaluation (to avoid race conditions)
    """
    if self._shutdown_event.is_set():
        return

    payload = dict(evaluation_run.model_dump())
    payload["associated_span_id"] = span_id
    # Snapshot the span now so later changes to it cannot alter what is sent
    payload["span_data"] = span_data.model_dump()
    payload["queued_at"] = time.time()

    self._span_queue.put({"type": "evaluation_run", "data": payload})
|
1235
|
+
|
1236
|
+
def flush(self):
    """
    Block until every item put on the span queue has been marked done.

    This waits on ``Queue.join()``; it does not send anything itself. Any
    exception raised while waiting is reported through ``warnings.warn``.
    """
    try:
        pending_items = self._span_queue
        pending_items.join()
    except Exception as e:
        warnings.warn(f"Error during flush: {e}")
|
1243
|
+
|
1244
|
+
def shutdown(self):
    """
    Stop accepting new items, wait for queued ones, and drop worker references.

    Calling this again after the shutdown event is set is a no-op. Errors
    are reported through ``warnings.warn``; the worker thread list is
    cleared whether or not the flush succeeded.
    """
    stop_signal = self._shutdown_event
    if stop_signal.is_set():
        return

    try:
        # Setting the event makes queue_span / queue_evaluation_run reject new items
        stop_signal.set()

        try:
            self.flush()
        except Exception as e:
            warnings.warn(f"Error during final flush: {e}")
    except Exception as e:
        warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
    finally:
        # Only the references are dropped here; the threads are not joined
        self._worker_threads.clear()
|
1263
|
+
|
1264
|
+
def get_queue_size(self) -> int:
    """Return the number of items currently reported by the span queue's ``qsize()``."""
    pending_items = self._span_queue
    return pending_items.qsize()
|
1267
|
+
|
651
1268
|
class _DeepTracer:
|
652
1269
|
_instance: Optional["_DeepTracer"] = None
|
653
1270
|
_lock: threading.Lock = threading.Lock()
|
@@ -657,6 +1274,9 @@ class _DeepTracer:
|
|
657
1274
|
_original_sys_trace: Optional[Callable] = None
|
658
1275
|
_original_threading_trace: Optional[Callable] = None
|
659
1276
|
|
1277
|
+
def __init__(self, tracer: 'Tracer'):
    """Store the tracer whose current trace/span accessors this deep tracer uses."""
    # Runs on every construction, so the stored tracer is replaced each time
    self._tracer = tracer
|
1279
|
+
|
660
1280
|
def _get_qual_name(self, frame) -> str:
|
661
1281
|
func_name = frame.f_code.co_name
|
662
1282
|
module_name = frame.f_globals.get("__name__", "unknown_module")
|
@@ -670,7 +1290,7 @@ class _DeepTracer:
|
|
670
1290
|
except Exception:
|
671
1291
|
return f"{module_name}.{func_name}"
|
672
1292
|
|
673
|
-
def __new__(cls):
|
1293
|
+
def __new__(cls, tracer: 'Tracer' = None):
|
674
1294
|
with cls._lock:
|
675
1295
|
if cls._instance is None:
|
676
1296
|
cls._instance = super().__new__(cls)
|
@@ -756,11 +1376,11 @@ class _DeepTracer:
|
|
756
1376
|
if event not in ("call", "return", "exception"):
|
757
1377
|
return
|
758
1378
|
|
759
|
-
current_trace =
|
1379
|
+
current_trace = self._tracer.get_current_trace()
|
760
1380
|
if not current_trace:
|
761
1381
|
return
|
762
1382
|
|
763
|
-
parent_span_id =
|
1383
|
+
parent_span_id = self._tracer.get_current_span()
|
764
1384
|
if not parent_span_id:
|
765
1385
|
return
|
766
1386
|
|
@@ -822,7 +1442,7 @@ class _DeepTracer:
|
|
822
1442
|
})
|
823
1443
|
self._span_stack.set(span_stack)
|
824
1444
|
|
825
|
-
token =
|
1445
|
+
token = self._tracer.set_current_span(span_id)
|
826
1446
|
frame.f_locals["_judgment_span_token"] = token
|
827
1447
|
|
828
1448
|
span = TraceSpan(
|
@@ -856,7 +1476,7 @@ class _DeepTracer:
|
|
856
1476
|
if not span_stack:
|
857
1477
|
return
|
858
1478
|
|
859
|
-
current_id =
|
1479
|
+
current_id = self._tracer.get_current_span()
|
860
1480
|
|
861
1481
|
span_data = None
|
862
1482
|
for i, entry in enumerate(reversed(span_stack)):
|
@@ -881,12 +1501,12 @@ class _DeepTracer:
|
|
881
1501
|
del current_trace._span_depths[span_data["span_id"]]
|
882
1502
|
|
883
1503
|
if span_stack:
|
884
|
-
|
1504
|
+
self._tracer.set_current_span(span_stack[-1]["span_id"])
|
885
1505
|
else:
|
886
|
-
|
1506
|
+
self._tracer.set_current_span(span_data["parent_span_id"])
|
887
1507
|
|
888
1508
|
if "_judgment_span_token" in frame.f_locals:
|
889
|
-
|
1509
|
+
self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
|
890
1510
|
|
891
1511
|
elif event == "exception":
|
892
1512
|
exc_type = arg[0]
|
@@ -925,18 +1545,28 @@ class _DeepTracer:
|
|
925
1545
|
self._original_threading_trace = None
|
926
1546
|
|
927
1547
|
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
1548
|
+
# NOTE: The commented-out `log` helper below appears to be unused; remove it once that is confirmed.
|
1549
|
+
|
1550
|
+
# def log(self, message: str, level: str = "info"):
|
1551
|
+
# """ Log a message with the span context """
|
1552
|
+
# current_trace = self._tracer.get_current_trace()
|
1553
|
+
# if current_trace:
|
1554
|
+
# current_trace.log(message, level)
|
1555
|
+
# else:
|
1556
|
+
# print(f"[{level}] {message}")
|
1557
|
+
# current_trace.record_output({"log": message})
|
936
1558
|
|
937
1559
|
class Tracer:
|
938
1560
|
_instance = None
|
939
1561
|
|
1562
|
+
# Tracer.current_trace class variable is currently used in wrap()
|
1563
|
+
# TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
|
1564
|
+
# Should be fine to do so as long as we keep Tracer as a singleton
|
1565
|
+
current_trace: Optional[TraceClient] = None
|
1566
|
+
# current_span_id: Optional[str] = None
|
1567
|
+
|
1568
|
+
trace_across_async_contexts: bool = False # BY default, we don't trace across async contexts
|
1569
|
+
|
940
1570
|
def __new__(cls, *args, **kwargs):
|
941
1571
|
if cls._instance is None:
|
942
1572
|
cls._instance = super(Tracer, cls).__new__(cls)
|
@@ -957,7 +1587,13 @@ class Tracer:
|
|
957
1587
|
s3_aws_secret_access_key: Optional[str] = None,
|
958
1588
|
s3_region_name: Optional[str] = None,
|
959
1589
|
offline_mode: bool = False,
|
960
|
-
deep_tracing: bool = True # Deep tracing is enabled by default
|
1590
|
+
deep_tracing: bool = True, # Deep tracing is enabled by default
|
1591
|
+
trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
|
1592
|
+
# Background span service configuration
|
1593
|
+
enable_background_spans: bool = True, # Enable background span service by default
|
1594
|
+
span_batch_size: int = 50, # Number of spans to batch before sending
|
1595
|
+
span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
|
1596
|
+
span_num_workers: int = 10 # Number of worker threads for span processing
|
961
1597
|
):
|
962
1598
|
if not hasattr(self, 'initialized'):
|
963
1599
|
if not api_key:
|
@@ -975,14 +1611,18 @@ class Tracer:
|
|
975
1611
|
self.api_key: str = api_key
|
976
1612
|
self.project_name: str = project_name or str(uuid.uuid4())
|
977
1613
|
self.organization_id: str = organization_id
|
978
|
-
self._current_trace: Optional[str] = None
|
979
|
-
self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
|
980
1614
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
981
1615
|
self.traces: List[Trace] = []
|
982
1616
|
self.initialized: bool = True
|
983
1617
|
self.enable_monitoring: bool = enable_monitoring
|
984
1618
|
self.enable_evaluations: bool = enable_evaluations
|
985
1619
|
self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
|
1620
|
+
self.span_id_to_previous_span_id: Dict[str, str] = {}
|
1621
|
+
self.trace_id_to_previous_trace: Dict[str, TraceClient] = {}
|
1622
|
+
self.current_span_id: Optional[str] = None
|
1623
|
+
self.current_trace: Optional[TraceClient] = None
|
1624
|
+
self.trace_across_async_contexts: bool = trace_across_async_contexts
|
1625
|
+
Tracer.trace_across_async_contexts = trace_across_async_contexts
|
986
1626
|
|
987
1627
|
# Initialize S3 storage if enabled
|
988
1628
|
self.use_s3 = use_s3
|
@@ -996,6 +1636,18 @@ class Tracer:
|
|
996
1636
|
)
|
997
1637
|
self.offline_mode: bool = offline_mode
|
998
1638
|
self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
|
1639
|
+
|
1640
|
+
# Initialize background span service
|
1641
|
+
self.enable_background_spans: bool = enable_background_spans
|
1642
|
+
self.background_span_service: Optional[BackgroundSpanService] = None
|
1643
|
+
if enable_background_spans and not offline_mode:
|
1644
|
+
self.background_span_service = BackgroundSpanService(
|
1645
|
+
judgment_api_key=api_key,
|
1646
|
+
organization_id=organization_id,
|
1647
|
+
batch_size=span_batch_size,
|
1648
|
+
flush_interval=span_flush_interval,
|
1649
|
+
num_workers=span_num_workers
|
1650
|
+
)
|
999
1651
|
|
1000
1652
|
elif hasattr(self, 'project_name') and self.project_name != project_name:
|
1001
1653
|
warnings.warn(
|
@@ -1006,16 +1658,44 @@ class Tracer:
|
|
1006
1658
|
)
|
1007
1659
|
|
1008
1660
|
def set_current_span(self, span_id: str):
    """
    Make ``span_id`` the current span.

    The previously current span id is remembered under ``span_id`` so that
    ``reset_current_span`` can restore it. The id is stored on the
    instance, on the ``Tracer`` class, and in the ``current_span_var``
    context variable.

    Args:
        span_id: ID of the span to make current

    Returns:
        The token returned by ``current_span_var.set``, or None if setting
        the context variable failed.
    """
    self.span_id_to_previous_span_id[span_id] = getattr(self, 'current_span_id', None)
    self.current_span_id = span_id
    Tracer.current_span_id = span_id
    try:
        token = current_span_var.set(span_id)
    except Exception:
        # Best effort: the instance/class state above is already updated.
        # Catch Exception rather than everything so that KeyboardInterrupt
        # and SystemExit still propagate.
        token = None
    return token
|
1010
1669
|
|
1011
1670
|
def get_current_span(self) -> Optional[str]:
    """
    Return the ID of the current span, or None if there is none.

    When ``trace_across_async_contexts`` is enabled, the span id stored on
    the instance takes precedence and the context variable is the fallback.
    Otherwise only the context variable is consulted.
    """
    try:
        context_span_id = current_span_var.get()
    except Exception:
        # Treat an unreadable context variable as "no span". Catch Exception
        # rather than everything so that KeyboardInterrupt and SystemExit
        # still propagate.
        context_span_id = None

    if self.trace_across_async_contexts:
        return self.current_span_id or context_span_id
    return context_span_id
|
1676
|
+
|
1677
|
+
def reset_current_span(self, token: Optional[Any] = None, span_id: Optional[str] = None):
    """
    Undo a previous ``set_current_span``.

    The context variable is reset with ``token`` on a best-effort basis,
    then the instance and ``Tracer`` class attributes are set to the span
    id that was current before ``span_id`` was set (None if unknown).

    Args:
        token: Token returned by ``set_current_span``; may be None
        span_id: ID of the span being left; defaults to the instance's
            current span id
    """
    if not span_id:
        span_id = self.current_span_id
    try:
        current_span_var.reset(token)
    except Exception:
        # A missing, stale, or foreign token must not prevent the
        # instance-level restore below. Catch Exception rather than
        # everything so that KeyboardInterrupt and SystemExit still propagate.
        pass
    self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
    Tracer.current_span_id = self.current_span_id
|
1013
1686
|
|
1014
1687
|
def set_current_trace(self, trace: TraceClient):
    """
    Make ``trace`` the current trace.

    The previously current trace is remembered under ``trace.trace_id`` so
    that ``reset_current_trace`` can restore it. The trace is stored on the
    instance, on the ``Tracer`` class, and in the ``current_trace_var``
    context variable.

    Args:
        trace: The trace client to make current

    Returns:
        The token returned by ``current_trace_var.set``, or None if setting
        the context variable failed.
    """
    self.trace_id_to_previous_trace[trace.trace_id] = getattr(self, 'current_trace', None)
    self.current_trace = trace
    Tracer.current_trace = trace
    try:
        token = current_trace_var.set(trace)
    except Exception:
        # Best effort: the instance/class state above is already updated.
        # Catch Exception rather than everything so that KeyboardInterrupt
        # and SystemExit still propagate.
        token = None
    return token
|
1019
1699
|
|
1020
1700
|
def get_current_trace(self) -> Optional[TraceClient]:
|
1021
1701
|
"""
|
@@ -1025,23 +1705,34 @@ class Tracer:
|
|
1025
1705
|
If not found (e.g., context lost across threads/tasks),
|
1026
1706
|
it falls back to the active trace client managed by the callback handler.
|
1027
1707
|
"""
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1708
|
+
try:
|
1709
|
+
current_trace_var_val = current_trace_var.get()
|
1710
|
+
except:
|
1711
|
+
current_trace_var_val = None
|
1712
|
+
|
1713
|
+
# Use context variable or class variable based on trace_across_async_contexts setting
|
1714
|
+
context_trace = (self.current_trace or current_trace_var_val) if self.trace_across_async_contexts else current_trace_var_val
|
1031
1715
|
|
1032
|
-
#
|
1716
|
+
# If we found a trace from context, return it
|
1717
|
+
if context_trace:
|
1718
|
+
return context_trace
|
1719
|
+
|
1720
|
+
# Fallback: Check the active client potentially set by a callback handler (e.g., LangGraph)
|
1033
1721
|
if hasattr(self, '_active_trace_client') and self._active_trace_client:
|
1034
|
-
# warnings.warn("Falling back to _active_trace_client in get_current_trace. ContextVar might be lost.", RuntimeWarning)
|
1035
1722
|
return self._active_trace_client
|
1036
1723
|
|
1037
|
-
# If neither is available
|
1038
|
-
# warnings.warn("No current trace found in context variable or active client fallback.", RuntimeWarning)
|
1724
|
+
# If neither is available, return None
|
1039
1725
|
return None
|
1040
|
-
|
1041
|
-
def
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1726
|
+
|
1727
|
+
def reset_current_trace(self, token: Optional[Any] = None, trace_id: Optional[str] = None):
    """
    Undo a previous ``set_current_trace``.

    The context variable is reset with ``token`` on a best-effort basis,
    then the instance and ``Tracer`` class attributes are set to the trace
    that was current before ``trace_id`` was set (None if unknown).

    Args:
        token: Token returned by ``set_current_trace``; may be None
        trace_id: ID of the trace being left; defaults to the ID of the
            instance's current trace when there is one
    """
    if not trace_id and self.current_trace:
        trace_id = self.current_trace.trace_id
    try:
        current_trace_var.reset(token)
    except Exception:
        # A missing, stale, or foreign token must not prevent the
        # instance-level restore below. Catch Exception rather than
        # everything so that KeyboardInterrupt and SystemExit still propagate.
        pass
    self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
    Tracer.current_trace = self.current_trace
|
1045
1736
|
|
1046
1737
|
@contextmanager
|
1047
1738
|
def trace(
|
@@ -1056,7 +1747,7 @@ class Tracer:
|
|
1056
1747
|
project = project_name if project_name is not None else self.project_name
|
1057
1748
|
|
1058
1749
|
# Get parent trace info from context
|
1059
|
-
parent_trace =
|
1750
|
+
parent_trace = self.get_current_trace()
|
1060
1751
|
parent_trace_id = None
|
1061
1752
|
parent_name = None
|
1062
1753
|
|
@@ -1078,7 +1769,7 @@ class Tracer:
|
|
1078
1769
|
)
|
1079
1770
|
|
1080
1771
|
# Set the current trace in context variables
|
1081
|
-
token =
|
1772
|
+
token = self.set_current_trace(trace)
|
1082
1773
|
|
1083
1774
|
# Automatically create top-level span
|
1084
1775
|
with trace.span(name or "unnamed_trace") as span:
|
@@ -1087,13 +1778,13 @@ class Tracer:
|
|
1087
1778
|
yield trace
|
1088
1779
|
finally:
|
1089
1780
|
# Reset the context variable
|
1090
|
-
|
1781
|
+
self.reset_current_trace(token)
|
1091
1782
|
|
1092
1783
|
|
1093
1784
|
def log(self, msg: str, label: str = "log", score: int = 1):
|
1094
1785
|
"""Log a message with the current span context"""
|
1095
|
-
current_span_id =
|
1096
|
-
current_trace =
|
1786
|
+
current_span_id = self.get_current_span()
|
1787
|
+
current_trace = self.get_current_trace()
|
1097
1788
|
if current_span_id:
|
1098
1789
|
annotation = TraceAnnotation(
|
1099
1790
|
span_id=current_span_id,
|
@@ -1237,7 +1928,7 @@ class Tracer:
|
|
1237
1928
|
agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
|
1238
1929
|
|
1239
1930
|
# Get current trace from context
|
1240
|
-
current_trace =
|
1931
|
+
current_trace = self.get_current_trace()
|
1241
1932
|
|
1242
1933
|
# If there's no current trace, create a root trace
|
1243
1934
|
if not current_trace:
|
@@ -1258,7 +1949,7 @@ class Tracer:
|
|
1258
1949
|
|
1259
1950
|
# Save empty trace and set trace context
|
1260
1951
|
# current_trace.save(empty_save=True, overwrite=overwrite)
|
1261
|
-
trace_token =
|
1952
|
+
trace_token = self.set_current_trace(current_trace)
|
1262
1953
|
|
1263
1954
|
try:
|
1264
1955
|
# Use span for the function execution within the root trace
|
@@ -1274,7 +1965,7 @@ class Tracer:
|
|
1274
1965
|
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
1275
1966
|
|
1276
1967
|
if use_deep_tracing:
|
1277
|
-
with _DeepTracer():
|
1968
|
+
with _DeepTracer(self):
|
1278
1969
|
result = await func(*args, **kwargs)
|
1279
1970
|
else:
|
1280
1971
|
try:
|
@@ -1290,12 +1981,31 @@ class Tracer:
|
|
1290
1981
|
span.record_output(result)
|
1291
1982
|
return result
|
1292
1983
|
finally:
|
1984
|
+
# Flush background spans before saving the trace
|
1985
|
+
|
1986
|
+
complete_trace_data = {
|
1987
|
+
"trace_id": current_trace.trace_id,
|
1988
|
+
"name": current_trace.name,
|
1989
|
+
"created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
|
1990
|
+
"duration": current_trace.get_duration(),
|
1991
|
+
"trace_spans": [span.model_dump() for span in current_trace.trace_spans],
|
1992
|
+
"overwrite": overwrite,
|
1993
|
+
"offline_mode": self.offline_mode,
|
1994
|
+
"parent_trace_id": current_trace.parent_trace_id,
|
1995
|
+
"parent_name": current_trace.parent_name
|
1996
|
+
}
|
1293
1997
|
# Save the completed trace
|
1294
|
-
trace_id,
|
1295
|
-
|
1998
|
+
trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
|
1999
|
+
|
2000
|
+
# Store the complete trace data instead of just server response
|
2001
|
+
|
2002
|
+
self.traces.append(complete_trace_data)
|
2003
|
+
|
2004
|
+
# if self.background_span_service:
|
2005
|
+
# self.background_span_service.flush()
|
1296
2006
|
|
1297
2007
|
# Reset trace context (span context resets automatically)
|
1298
|
-
|
2008
|
+
self.reset_current_trace(trace_token)
|
1299
2009
|
else:
|
1300
2010
|
with current_trace.span(span_name, span_type=span_type) as span:
|
1301
2011
|
inputs = combine_args_kwargs(func, args, kwargs)
|
@@ -1307,7 +2017,7 @@ class Tracer:
|
|
1307
2017
|
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
1308
2018
|
|
1309
2019
|
if use_deep_tracing:
|
1310
|
-
with _DeepTracer():
|
2020
|
+
with _DeepTracer(self):
|
1311
2021
|
result = await func(*args, **kwargs)
|
1312
2022
|
else:
|
1313
2023
|
try:
|
@@ -1336,7 +2046,7 @@ class Tracer:
|
|
1336
2046
|
class_name = args[0].__class__.__name__
|
1337
2047
|
agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
|
1338
2048
|
# Get current trace from context
|
1339
|
-
current_trace =
|
2049
|
+
current_trace = self.get_current_trace()
|
1340
2050
|
|
1341
2051
|
# If there's no current trace, create a root trace
|
1342
2052
|
if not current_trace:
|
@@ -1357,7 +2067,7 @@ class Tracer:
|
|
1357
2067
|
|
1358
2068
|
# Save empty trace and set trace context
|
1359
2069
|
# current_trace.save(empty_save=True, overwrite=overwrite)
|
1360
|
-
trace_token =
|
2070
|
+
trace_token = self.set_current_trace(current_trace)
|
1361
2071
|
|
1362
2072
|
try:
|
1363
2073
|
# Use span for the function execution within the root trace
|
@@ -1372,7 +2082,7 @@ class Tracer:
|
|
1372
2082
|
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
1373
2083
|
|
1374
2084
|
if use_deep_tracing:
|
1375
|
-
with _DeepTracer():
|
2085
|
+
with _DeepTracer(self):
|
1376
2086
|
result = func(*args, **kwargs)
|
1377
2087
|
else:
|
1378
2088
|
try:
|
@@ -1389,12 +2099,27 @@ class Tracer:
|
|
1389
2099
|
span.record_output(result)
|
1390
2100
|
return result
|
1391
2101
|
finally:
|
1392
|
-
#
|
1393
|
-
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1394
|
-
self.traces.append(trace)
|
2102
|
+
# Flush background spans before saving the trace
|
1395
2103
|
|
2104
|
+
|
2105
|
+
# Save the completed trace
|
2106
|
+
trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
|
2107
|
+
|
2108
|
+
# Store the complete trace data instead of just server response
|
2109
|
+
complete_trace_data = {
|
2110
|
+
"trace_id": current_trace.trace_id,
|
2111
|
+
"name": current_trace.name,
|
2112
|
+
"created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
|
2113
|
+
"duration": current_trace.get_duration(),
|
2114
|
+
"trace_spans": [span.model_dump() for span in current_trace.trace_spans],
|
2115
|
+
"overwrite": overwrite,
|
2116
|
+
"offline_mode": self.offline_mode,
|
2117
|
+
"parent_trace_id": current_trace.parent_trace_id,
|
2118
|
+
"parent_name": current_trace.parent_name
|
2119
|
+
}
|
2120
|
+
self.traces.append(complete_trace_data)
|
1396
2121
|
# Reset trace context (span context resets automatically)
|
1397
|
-
|
2122
|
+
self.reset_current_trace(trace_token)
|
1398
2123
|
else:
|
1399
2124
|
with current_trace.span(span_name, span_type=span_type) as span:
|
1400
2125
|
|
@@ -1407,7 +2132,7 @@ class Tracer:
|
|
1407
2132
|
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
1408
2133
|
|
1409
2134
|
if use_deep_tracing:
|
1410
|
-
with _DeepTracer():
|
2135
|
+
with _DeepTracer(self):
|
1411
2136
|
result = func(*args, **kwargs)
|
1412
2137
|
else:
|
1413
2138
|
try:
|
@@ -1424,6 +2149,55 @@ class Tracer:
|
|
1424
2149
|
|
1425
2150
|
return wrapper
|
1426
2151
|
|
2152
|
+
def observe_tools(self, cls=None, *, exclude_methods: Optional[List[str]] = None,
                  include_private: bool = False, warn_on_double_decoration: bool = True):
    """
    Automatically adds @observe(span_type="tool") to all methods in a class.

    Works both as ``@tracer.observe_tools`` and as
    ``@tracer.observe_tools(...)``. Every callable attribute reported by
    ``dir(cls)`` is considered, which includes inherited ones. Static and
    class methods are wrapped via their underlying function and re-wrapped
    in ``staticmethod`` / ``classmethod`` so they keep their calling
    convention. When monitoring is disabled the class is returned untouched.

    Args:
        cls: The class to decorate (automatically provided when used as decorator)
        exclude_methods: List of method names to skip decorating. Defaults to common magic methods
        include_private: Whether to decorate methods starting with underscore. Defaults to False
        warn_on_double_decoration: Whether to print warnings when skipping already-decorated
            methods or when decorating a method fails. Defaults to True

    Returns:
        The decorated class, or a class decorator when ``cls`` is None.
    """
    if exclude_methods is None:
        exclude_methods = ['__init__', '__new__', '__del__', '__str__', '__repr__']

    def decorate_class(cls):
        if not self.enable_monitoring:
            return cls

        for name in dir(cls):
            # Filter by name first so excluded attributes are never fetched
            if name in exclude_methods or (name.startswith('_') and not include_private):
                continue

            method = getattr(cls, name)
            if not callable(method):
                continue

            # observe() marks its wrappers with _judgment_span_name
            if hasattr(method, '_judgment_span_name'):
                if warn_on_double_decoration:
                    print(f"Warning: {cls.__name__}.{name} already decorated, skipping")
                continue

            # Look at the raw class attribute: getattr() hides whether it is
            # a staticmethod/classmethod, and re-binding those as plain
            # functions would change how they must be called.
            raw_attr = next(
                (vars(klass)[name] for klass in getattr(cls, '__mro__', ()) if name in vars(klass)),
                None,
            )

            try:
                if isinstance(raw_attr, (staticmethod, classmethod)):
                    rewrap = type(raw_attr)
                    decorated_method = rewrap(self.observe(raw_attr.__func__, span_type="tool"))
                else:
                    decorated_method = self.observe(method, span_type="tool")
                setattr(cls, name, decorated_method)
            except Exception as e:
                if warn_on_double_decoration:
                    print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")

        return cls

    return decorate_class if cls is None else decorate_class(cls)
|
2200
|
+
|
1427
2201
|
def async_evaluate(self, *args, **kwargs):
|
1428
2202
|
if not self.enable_evaluations:
|
1429
2203
|
return
|
@@ -1431,13 +2205,7 @@ class Tracer:
|
|
1431
2205
|
# --- Get trace_id passed explicitly (if any) ---
|
1432
2206
|
passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
|
1433
2207
|
|
1434
|
-
|
1435
|
-
current_trace = current_trace_var.get()
|
1436
|
-
|
1437
|
-
# --- Fallback Logic: Use active client only if context var is empty ---
|
1438
|
-
if not current_trace:
|
1439
|
-
current_trace = self._active_trace_client # Use the fallback
|
1440
|
-
# --- End Fallback Logic ---
|
2208
|
+
current_trace = self.get_current_trace()
|
1441
2209
|
|
1442
2210
|
if current_trace:
|
1443
2211
|
# Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
|
@@ -1448,13 +2216,34 @@ class Tracer:
|
|
1448
2216
|
else:
|
1449
2217
|
warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
|
1450
2218
|
|
1451
|
-
def
|
2219
|
+
def get_background_span_service(self) -> Optional[BackgroundSpanService]:
    """Return the tracer's background span service, or None when it has none."""
    service = self.background_span_service
    return service
|
2222
|
+
|
2223
|
+
def flush_background_spans(self):
    """Call ``flush()`` on the background span service, if one is set."""
    service = self.background_span_service
    if not service:
        return
    service.flush()
|
2227
|
+
|
2228
|
+
def shutdown_background_service(self):
    """Shut down the background span service, if one is set, and drop the reference to it."""
    service = self.background_span_service
    if not service:
        return
    service.shutdown()
    # Clearing the attribute makes later flush/shutdown calls no-ops
    self.background_span_service = None
|
2233
|
+
|
2234
|
+
def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts) -> Any:
|
1452
2235
|
"""
|
1453
2236
|
Wraps an API client to add tracing capabilities.
|
1454
2237
|
Supports OpenAI, Together, Anthropic, and Google GenAI clients.
|
1455
2238
|
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
|
1456
2239
|
"""
|
1457
2240
|
span_name, original_create, original_responses_create, original_stream = _get_client_config(client)
|
2241
|
+
|
2242
|
+
def _get_current_trace():
|
2243
|
+
if trace_across_async_contexts:
|
2244
|
+
return Tracer.current_trace
|
2245
|
+
else:
|
2246
|
+
return current_trace_var.get()
|
1458
2247
|
|
1459
2248
|
def _record_input_and_check_streaming(span, kwargs, is_responses=False):
|
1460
2249
|
"""Record input and check for streaming"""
|
@@ -1490,11 +2279,21 @@ def wrap(client: Any) -> Any:
|
|
1490
2279
|
output, usage = format_func(client, response)
|
1491
2280
|
span.record_output(output)
|
1492
2281
|
span.record_usage(usage)
|
2282
|
+
|
2283
|
+
# Queue the completed LLM span now that it has all data (input, output, usage)
|
2284
|
+
current_trace = _get_current_trace()
|
2285
|
+
if current_trace and current_trace.background_span_service:
|
2286
|
+
# Get the current span from the trace client
|
2287
|
+
current_span_id = current_trace.get_current_span()
|
2288
|
+
if current_span_id and current_span_id in current_trace.span_id_to_span:
|
2289
|
+
completed_span = current_trace.span_id_to_span[current_span_id]
|
2290
|
+
current_trace.background_span_service.queue_span(completed_span, span_state="completed")
|
2291
|
+
|
1493
2292
|
return response
|
1494
2293
|
|
1495
2294
|
# --- Traced Async Functions ---
|
1496
2295
|
async def traced_create_async(*args, **kwargs):
|
1497
|
-
current_trace =
|
2296
|
+
current_trace = _get_current_trace()
|
1498
2297
|
if not current_trace:
|
1499
2298
|
return await original_create(*args, **kwargs)
|
1500
2299
|
|
@@ -1510,7 +2309,7 @@ def wrap(client: Any) -> Any:
|
|
1510
2309
|
|
1511
2310
|
# Async responses for OpenAI clients
|
1512
2311
|
async def traced_response_create_async(*args, **kwargs):
|
1513
|
-
current_trace =
|
2312
|
+
current_trace = _get_current_trace()
|
1514
2313
|
if not current_trace:
|
1515
2314
|
return await original_responses_create(*args, **kwargs)
|
1516
2315
|
|
@@ -1526,7 +2325,7 @@ def wrap(client: Any) -> Any:
|
|
1526
2325
|
|
1527
2326
|
# Function replacing .stream() for async clients
|
1528
2327
|
def traced_stream_async(*args, **kwargs):
|
1529
|
-
current_trace =
|
2328
|
+
current_trace = _get_current_trace()
|
1530
2329
|
if not current_trace or not original_stream:
|
1531
2330
|
return original_stream(*args, **kwargs)
|
1532
2331
|
|
@@ -1542,7 +2341,7 @@ def wrap(client: Any) -> Any:
|
|
1542
2341
|
|
1543
2342
|
# --- Traced Sync Functions ---
|
1544
2343
|
def traced_create_sync(*args, **kwargs):
|
1545
|
-
current_trace =
|
2344
|
+
current_trace = _get_current_trace()
|
1546
2345
|
if not current_trace:
|
1547
2346
|
return original_create(*args, **kwargs)
|
1548
2347
|
|
@@ -1557,7 +2356,7 @@ def wrap(client: Any) -> Any:
|
|
1557
2356
|
raise e
|
1558
2357
|
|
1559
2358
|
def traced_response_create_sync(*args, **kwargs):
|
1560
|
-
current_trace =
|
2359
|
+
current_trace = _get_current_trace()
|
1561
2360
|
if not current_trace:
|
1562
2361
|
return original_responses_create(*args, **kwargs)
|
1563
2362
|
|
@@ -1573,7 +2372,7 @@ def wrap(client: Any) -> Any:
|
|
1573
2372
|
|
1574
2373
|
# Function replacing sync .stream()
|
1575
2374
|
def traced_stream_sync(*args, **kwargs):
|
1576
|
-
current_trace =
|
2375
|
+
current_trace = _get_current_trace()
|
1577
2376
|
if not current_trace or not original_stream:
|
1578
2377
|
return original_stream(*args, **kwargs)
|
1579
2378
|
|
@@ -1592,6 +2391,8 @@ def wrap(client: Any) -> Any:
|
|
1592
2391
|
client.chat.completions.create = traced_create_async
|
1593
2392
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
1594
2393
|
client.responses.create = traced_response_create_async
|
2394
|
+
if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
|
2395
|
+
client.beta.chat.completions.parse = traced_create_async
|
1595
2396
|
elif isinstance(client, AsyncAnthropic):
|
1596
2397
|
client.messages.create = traced_create_async
|
1597
2398
|
if original_stream:
|
@@ -1602,6 +2403,8 @@ def wrap(client: Any) -> Any:
|
|
1602
2403
|
client.chat.completions.create = traced_create_sync
|
1603
2404
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
1604
2405
|
client.responses.create = traced_response_create_sync
|
2406
|
+
if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
|
2407
|
+
client.beta.chat.completions.parse = traced_create_sync
|
1605
2408
|
elif isinstance(client, Anthropic):
|
1606
2409
|
client.messages.create = traced_create_sync
|
1607
2410
|
if original_stream:
|
@@ -1928,6 +2731,15 @@ def _sync_stream_wrapper(
|
|
1928
2731
|
# Update the trace entry with the accumulated content and usage
|
1929
2732
|
span.output = "".join(content_parts)
|
1930
2733
|
span.usage = final_usage
|
2734
|
+
|
2735
|
+
# Queue the completed LLM span now that streaming is done and all data is available
|
2736
|
+
# Note: We need to get the TraceClient that owns this span to access the background service
|
2737
|
+
# We can find this through the tracer singleton since spans are associated with traces
|
2738
|
+
from judgeval.common.tracer import Tracer
|
2739
|
+
tracer_instance = Tracer._instance
|
2740
|
+
if tracer_instance and tracer_instance.background_span_service:
|
2741
|
+
tracer_instance.background_span_service.queue_span(span, span_state="completed")
|
2742
|
+
|
1931
2743
|
# Note: We might need to adjust _serialize_output if this dict causes issues,
|
1932
2744
|
# but Pydantic's model_dump should handle dicts.
|
1933
2745
|
|
@@ -2012,6 +2824,12 @@ async def _async_stream_wrapper(
|
|
2012
2824
|
span.usage = usage_info
|
2013
2825
|
start_ts = getattr(span, 'created_at', time.time())
|
2014
2826
|
span.duration = time.time() - start_ts
|
2827
|
+
|
2828
|
+
# Queue the completed LLM span now that async streaming is done and all data is available
|
2829
|
+
from judgeval.common.tracer import Tracer
|
2830
|
+
tracer_instance = Tracer._instance
|
2831
|
+
if tracer_instance and tracer_instance.background_span_service:
|
2832
|
+
tracer_instance.background_span_service.queue_span(span, span_state="completed")
|
2015
2833
|
# else: # Handle error case if necessary, but remove debug print
|
2016
2834
|
|
2017
2835
|
def cost_per_token(*args, **kwargs):
|
@@ -2060,12 +2878,12 @@ class _BaseStreamManagerWrapper:
|
|
2060
2878
|
|
2061
2879
|
class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
|
2062
2880
|
async def __aenter__(self):
|
2063
|
-
self._parent_span_id_at_entry =
|
2881
|
+
self._parent_span_id_at_entry = self._trace_client.get_current_span()
|
2064
2882
|
if not self._trace_client:
|
2065
2883
|
return await self._original_manager.__aenter__()
|
2066
2884
|
|
2067
2885
|
span_id, span = self._create_span()
|
2068
|
-
self._span_context_token =
|
2886
|
+
self._span_context_token = self._trace_client.set_current_span(span_id)
|
2069
2887
|
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
2070
2888
|
|
2071
2889
|
# Call the original __aenter__ and expect it to be an async generator
|
@@ -2075,20 +2893,20 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
|
|
2075
2893
|
|
2076
2894
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
2077
2895
|
if hasattr(self, '_span_context_token'):
|
2078
|
-
span_id =
|
2896
|
+
span_id = self._trace_client.get_current_span()
|
2079
2897
|
self._finalize_span(span_id)
|
2080
|
-
|
2898
|
+
self._trace_client.reset_current_span(self._span_context_token)
|
2081
2899
|
delattr(self, '_span_context_token')
|
2082
2900
|
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
2083
2901
|
|
2084
2902
|
class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
|
2085
2903
|
def __enter__(self):
|
2086
|
-
self._parent_span_id_at_entry =
|
2904
|
+
self._parent_span_id_at_entry = self._trace_client.get_current_span()
|
2087
2905
|
if not self._trace_client:
|
2088
2906
|
return self._original_manager.__enter__()
|
2089
2907
|
|
2090
2908
|
span_id, span = self._create_span()
|
2091
|
-
self._span_context_token =
|
2909
|
+
self._span_context_token = self._trace_client.set_current_span(span_id)
|
2092
2910
|
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
2093
2911
|
|
2094
2912
|
raw_iterator = self._original_manager.__enter__()
|
@@ -2097,9 +2915,9 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
|
|
2097
2915
|
|
2098
2916
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
2099
2917
|
if hasattr(self, '_span_context_token'):
|
2100
|
-
span_id =
|
2918
|
+
span_id = self._trace_client.get_current_span()
|
2101
2919
|
self._finalize_span(span_id)
|
2102
|
-
|
2920
|
+
self._trace_client.reset_current_span(self._span_context_token)
|
2103
2921
|
delattr(self, '_span_context_token')
|
2104
2922
|
return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
|
2105
2923
|
|
@@ -2118,4 +2936,4 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
|
|
2118
2936
|
return instance_name
|
2119
2937
|
else:
|
2120
2938
|
raise Exception(f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator.")
|
2121
|
-
return None
|
2939
|
+
return None
|