judgeval 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/s3_storage.py +3 -1
- judgeval/common/tracer.py +1079 -139
- judgeval/common/utils.py +6 -2
- judgeval/constants.py +5 -0
- judgeval/data/datasets/dataset.py +12 -6
- judgeval/data/datasets/eval_dataset_client.py +3 -1
- judgeval/data/trace.py +7 -2
- judgeval/integrations/langgraph.py +218 -34
- judgeval/judgment_client.py +9 -1
- judgeval/rules.py +60 -50
- judgeval/run_evaluation.py +53 -29
- judgeval/scorers/judgeval_scorer.py +4 -1
- judgeval/scorers/prompt_scorer.py +3 -0
- judgeval/utils/alerts.py +8 -0
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/METADATA +48 -50
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/RECORD +18 -18
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/WHEEL +0 -0
- {judgeval-0.0.40.dist-info → judgeval-0.0.42.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -5,7 +5,6 @@ Tracing system for judgeval that allows for function tracing using decorators.
|
|
5
5
|
import asyncio
|
6
6
|
import functools
|
7
7
|
import inspect
|
8
|
-
import json
|
9
8
|
import os
|
10
9
|
import site
|
11
10
|
import sysconfig
|
@@ -16,9 +15,10 @@ import uuid
|
|
16
15
|
import warnings
|
17
16
|
import contextvars
|
18
17
|
import sys
|
18
|
+
import json
|
19
19
|
from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
|
20
20
|
from dataclasses import dataclass, field
|
21
|
-
from datetime import datetime
|
21
|
+
from datetime import datetime, timezone
|
22
22
|
from http import HTTPStatus
|
23
23
|
from typing import (
|
24
24
|
Any,
|
@@ -29,20 +29,16 @@ from typing import (
|
|
29
29
|
Literal,
|
30
30
|
Optional,
|
31
31
|
Tuple,
|
32
|
-
Type,
|
33
|
-
TypeVar,
|
34
32
|
Union,
|
35
33
|
AsyncGenerator,
|
36
34
|
TypeAlias,
|
37
|
-
Set
|
38
35
|
)
|
39
36
|
from rich import print as rprint
|
40
|
-
import types
|
37
|
+
import types
|
41
38
|
|
42
39
|
# Third-party imports
|
43
40
|
import requests
|
44
41
|
from litellm import cost_per_token as _original_cost_per_token
|
45
|
-
from pydantic import BaseModel
|
46
42
|
from rich import print as rprint
|
47
43
|
from openai import OpenAI, AsyncOpenAI
|
48
44
|
from together import Together, AsyncTogether
|
@@ -53,24 +49,30 @@ from google import genai
|
|
53
49
|
from judgeval.constants import (
|
54
50
|
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
55
51
|
JUDGMENT_TRACES_SAVE_API_URL,
|
52
|
+
JUDGMENT_TRACES_UPSERT_API_URL,
|
53
|
+
JUDGMENT_TRACES_USAGE_CHECK_API_URL,
|
54
|
+
JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
|
56
55
|
JUDGMENT_TRACES_FETCH_API_URL,
|
57
56
|
RABBITMQ_HOST,
|
58
57
|
RABBITMQ_PORT,
|
59
58
|
RABBITMQ_QUEUE,
|
60
59
|
JUDGMENT_TRACES_DELETE_API_URL,
|
61
60
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
61
|
+
JUDGMENT_TRACES_SPANS_BATCH_API_URL,
|
62
|
+
JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
|
62
63
|
)
|
63
64
|
from judgeval.data import Example, Trace, TraceSpan, TraceUsage
|
64
65
|
from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
|
65
66
|
from judgeval.rules import Rule
|
66
67
|
from judgeval.evaluation_run import EvaluationRun
|
67
|
-
from judgeval.
|
68
|
-
from judgeval.common.utils import validate_api_key
|
68
|
+
from judgeval.common.utils import ExcInfo, validate_api_key
|
69
69
|
from judgeval.common.exceptions import JudgmentAPIError
|
70
70
|
|
71
71
|
# Standard library imports needed for the new class
|
72
72
|
import concurrent.futures
|
73
73
|
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
74
|
+
import queue
|
75
|
+
import atexit
|
74
76
|
|
75
77
|
# Define context variables for tracking the current trace and the current span within a trace
|
76
78
|
current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
|
@@ -147,13 +149,18 @@ class TraceManagerClient:
|
|
147
149
|
|
148
150
|
return response.json()
|
149
151
|
|
150
|
-
def save_trace(self, trace_data: dict, offline_mode: bool = False):
|
152
|
+
def save_trace(self, trace_data: dict, offline_mode: bool = False, final_save: bool = True):
|
151
153
|
"""
|
152
154
|
Saves a trace to the Judgment Supabase and optionally to S3 if configured.
|
153
155
|
|
154
156
|
Args:
|
155
157
|
trace_data: The trace data to save
|
158
|
+
offline_mode: Whether running in offline mode
|
159
|
+
final_save: Whether this is the final save (controls S3 saving)
|
156
160
|
NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
|
161
|
+
|
162
|
+
Returns:
|
163
|
+
dict: Server response containing UI URL and other metadata
|
157
164
|
"""
|
158
165
|
# Save to Judgment API
|
159
166
|
|
@@ -175,7 +182,6 @@ class TraceManagerClient:
|
|
175
182
|
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
176
183
|
|
177
184
|
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
178
|
-
|
179
185
|
response = requests.post(
|
180
186
|
JUDGMENT_TRACES_SAVE_API_URL,
|
181
187
|
data=serialized_trace_data,
|
@@ -192,8 +198,107 @@ class TraceManagerClient:
|
|
192
198
|
elif response.status_code != HTTPStatus.OK:
|
193
199
|
raise ValueError(f"Failed to save trace data: {response.text}")
|
194
200
|
|
195
|
-
#
|
196
|
-
|
201
|
+
# Parse server response
|
202
|
+
server_response = response.json()
|
203
|
+
|
204
|
+
# If S3 storage is enabled, save to S3 only on final save
|
205
|
+
if self.tracer and self.tracer.use_s3 and final_save:
|
206
|
+
try:
|
207
|
+
s3_key = self.tracer.s3_storage.save_trace(
|
208
|
+
trace_data=trace_data,
|
209
|
+
trace_id=trace_data["trace_id"],
|
210
|
+
project_name=trace_data["project_name"]
|
211
|
+
)
|
212
|
+
print(f"Trace also saved to S3 at key: {s3_key}")
|
213
|
+
except Exception as e:
|
214
|
+
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
215
|
+
|
216
|
+
if not offline_mode and "ui_results_url" in server_response:
|
217
|
+
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
218
|
+
rprint(pretty_str)
|
219
|
+
|
220
|
+
return server_response
|
221
|
+
|
222
|
+
def check_usage_limits(self, count: int = 1):
|
223
|
+
"""
|
224
|
+
Check if the organization can use the requested number of traces without exceeding limits.
|
225
|
+
|
226
|
+
Args:
|
227
|
+
count: Number of traces to check for (default: 1)
|
228
|
+
|
229
|
+
Returns:
|
230
|
+
dict: Server response with rate limit status and usage info
|
231
|
+
|
232
|
+
Raises:
|
233
|
+
ValueError: If rate limits would be exceeded or other errors occur
|
234
|
+
"""
|
235
|
+
response = requests.post(
|
236
|
+
JUDGMENT_TRACES_USAGE_CHECK_API_URL,
|
237
|
+
json={"count": count},
|
238
|
+
headers={
|
239
|
+
"Content-Type": "application/json",
|
240
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
241
|
+
"X-Organization-Id": self.organization_id
|
242
|
+
},
|
243
|
+
verify=True
|
244
|
+
)
|
245
|
+
|
246
|
+
if response.status_code == HTTPStatus.FORBIDDEN:
|
247
|
+
# Rate limits exceeded
|
248
|
+
error_data = response.json()
|
249
|
+
raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
|
250
|
+
elif response.status_code != HTTPStatus.OK:
|
251
|
+
raise ValueError(f"Failed to check usage limits: {response.text}")
|
252
|
+
|
253
|
+
return response.json()
|
254
|
+
|
255
|
+
def upsert_trace(self, trace_data: dict, offline_mode: bool = False, show_link: bool = True, final_save: bool = True):
|
256
|
+
"""
|
257
|
+
Upserts a trace to the Judgment API (always overwrites if exists).
|
258
|
+
|
259
|
+
Args:
|
260
|
+
trace_data: The trace data to upsert
|
261
|
+
offline_mode: Whether running in offline mode
|
262
|
+
show_link: Whether to show the UI link (for live tracing)
|
263
|
+
final_save: Whether this is the final save (controls S3 saving)
|
264
|
+
|
265
|
+
Returns:
|
266
|
+
dict: Server response containing UI URL and other metadata
|
267
|
+
"""
|
268
|
+
def fallback_encoder(obj):
|
269
|
+
"""
|
270
|
+
Custom JSON encoder fallback.
|
271
|
+
Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
|
272
|
+
"""
|
273
|
+
try:
|
274
|
+
return repr(obj)
|
275
|
+
except Exception:
|
276
|
+
try:
|
277
|
+
return str(obj)
|
278
|
+
except Exception as e:
|
279
|
+
return f"<Unserializable object of type {type(obj).__name__}: {e}>"
|
280
|
+
|
281
|
+
serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
|
282
|
+
|
283
|
+
response = requests.post(
|
284
|
+
JUDGMENT_TRACES_UPSERT_API_URL,
|
285
|
+
data=serialized_trace_data,
|
286
|
+
headers={
|
287
|
+
"Content-Type": "application/json",
|
288
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
289
|
+
"X-Organization-Id": self.organization_id
|
290
|
+
},
|
291
|
+
verify=True
|
292
|
+
)
|
293
|
+
|
294
|
+
if response.status_code != HTTPStatus.OK:
|
295
|
+
raise ValueError(f"Failed to upsert trace data: {response.text}")
|
296
|
+
|
297
|
+
# Parse server response
|
298
|
+
server_response = response.json()
|
299
|
+
|
300
|
+
# If S3 storage is enabled, save to S3 only on final save
|
301
|
+
if self.tracer and self.tracer.use_s3 and final_save:
|
197
302
|
try:
|
198
303
|
s3_key = self.tracer.s3_storage.save_trace(
|
199
304
|
trace_data=trace_data,
|
@@ -204,9 +309,40 @@ class TraceManagerClient:
|
|
204
309
|
except Exception as e:
|
205
310
|
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
206
311
|
|
207
|
-
if not offline_mode and "ui_results_url" in
|
208
|
-
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={
|
312
|
+
if not offline_mode and show_link and "ui_results_url" in server_response:
|
313
|
+
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
|
209
314
|
rprint(pretty_str)
|
315
|
+
|
316
|
+
return server_response
|
317
|
+
|
318
|
+
def update_usage_counters(self, count: int = 1):
|
319
|
+
"""
|
320
|
+
Update trace usage counters after successfully saving traces.
|
321
|
+
|
322
|
+
Args:
|
323
|
+
count: Number of traces to count (default: 1)
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
dict: Server response with updated usage information
|
327
|
+
|
328
|
+
Raises:
|
329
|
+
ValueError: If the update fails
|
330
|
+
"""
|
331
|
+
response = requests.post(
|
332
|
+
JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
|
333
|
+
json={"count": count},
|
334
|
+
headers={
|
335
|
+
"Content-Type": "application/json",
|
336
|
+
"Authorization": f"Bearer {self.judgment_api_key}",
|
337
|
+
"X-Organization-Id": self.organization_id
|
338
|
+
},
|
339
|
+
verify=True
|
340
|
+
)
|
341
|
+
|
342
|
+
if response.status_code != HTTPStatus.OK:
|
343
|
+
raise ValueError(f"Failed to update usage counters: {response.text}")
|
344
|
+
|
345
|
+
return response.json()
|
210
346
|
|
211
347
|
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
212
348
|
def save_annotation(self, annotation: TraceAnnotation):
|
@@ -307,7 +443,7 @@ class TraceClient:
|
|
307
443
|
tracer: Optional["Tracer"],
|
308
444
|
trace_id: Optional[str] = None,
|
309
445
|
name: str = "default",
|
310
|
-
project_name: str =
|
446
|
+
project_name: str = None,
|
311
447
|
overwrite: bool = False,
|
312
448
|
rules: Optional[List[Rule]] = None,
|
313
449
|
enable_monitoring: bool = True,
|
@@ -317,7 +453,7 @@ class TraceClient:
|
|
317
453
|
):
|
318
454
|
self.name = name
|
319
455
|
self.trace_id = trace_id or str(uuid.uuid4())
|
320
|
-
self.project_name = project_name
|
456
|
+
self.project_name = project_name or str(uuid.uuid4())
|
321
457
|
self.overwrite = overwrite
|
322
458
|
self.tracer = tracer
|
323
459
|
self.rules = rules or []
|
@@ -329,35 +465,48 @@ class TraceClient:
|
|
329
465
|
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
330
466
|
self.evaluation_runs: List[EvaluationRun] = []
|
331
467
|
self.annotations: List[TraceAnnotation] = []
|
332
|
-
self.start_time =
|
468
|
+
self.start_time = None # Will be set after first successful save
|
333
469
|
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
|
334
470
|
self.visited_nodes = []
|
335
471
|
self.executed_tools = []
|
336
472
|
self.executed_node_tools = []
|
337
473
|
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
474
|
+
|
475
|
+
# Get background span service from tracer
|
476
|
+
self.background_span_service = tracer.get_background_span_service() if tracer else None
|
338
477
|
|
339
478
|
def get_current_span(self):
|
340
479
|
"""Get the current span from the context var"""
|
341
|
-
return
|
480
|
+
return self.tracer.get_current_span()
|
342
481
|
|
343
482
|
def set_current_span(self, span: Any):
|
344
483
|
"""Set the current span from the context var"""
|
345
|
-
return
|
484
|
+
return self.tracer.set_current_span(span)
|
346
485
|
|
347
486
|
def reset_current_span(self, token: Any):
|
348
487
|
"""Reset the current span from the context var"""
|
349
|
-
|
488
|
+
self.tracer.reset_current_span(token)
|
350
489
|
|
351
490
|
@contextmanager
|
352
491
|
def span(self, name: str, span_type: SpanType = "span"):
|
353
492
|
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
493
|
+
is_first_span = len(self.trace_spans) == 0
|
494
|
+
if is_first_span:
|
495
|
+
try:
|
496
|
+
trace_id, server_response = self.save_with_rate_limiting(overwrite=self.overwrite, final_save=False)
|
497
|
+
# Set start_time after first successful save
|
498
|
+
if self.start_time is None:
|
499
|
+
self.start_time = time.time()
|
500
|
+
# Link will be shown by upsert_trace method
|
501
|
+
except Exception as e:
|
502
|
+
warnings.warn(f"Failed to save initial trace for live tracking: {e}")
|
354
503
|
start_time = time.time()
|
355
504
|
|
356
505
|
# Generate a unique ID for *this specific span invocation*
|
357
506
|
span_id = str(uuid.uuid4())
|
358
507
|
|
359
|
-
parent_span_id =
|
360
|
-
token =
|
508
|
+
parent_span_id = self.get_current_span() # Get ID of the parent span from context var
|
509
|
+
token = self.set_current_span(span_id) # Set *this* span's ID as the current one
|
361
510
|
|
362
511
|
current_depth = 0
|
363
512
|
if parent_span_id and parent_span_id in self._span_depths:
|
@@ -377,16 +526,27 @@ class TraceClient:
|
|
377
526
|
)
|
378
527
|
self.add_span(span)
|
379
528
|
|
529
|
+
|
530
|
+
|
531
|
+
# Queue span with initial state (input phase)
|
532
|
+
if self.background_span_service:
|
533
|
+
self.background_span_service.queue_span(span, span_state="input")
|
534
|
+
|
380
535
|
try:
|
381
536
|
yield self
|
382
537
|
finally:
|
383
538
|
duration = time.time() - start_time
|
384
539
|
span.duration = duration
|
540
|
+
|
541
|
+
# Queue span with completed state (output phase)
|
542
|
+
if self.background_span_service:
|
543
|
+
self.background_span_service.queue_span(span, span_state="completed")
|
544
|
+
|
385
545
|
# Clean up depth tracking for this span_id
|
386
546
|
if span_id in self._span_depths:
|
387
547
|
del self._span_depths[span_id]
|
388
548
|
# Reset context var
|
389
|
-
|
549
|
+
self.reset_current_span(token)
|
390
550
|
|
391
551
|
def async_evaluate(
|
392
552
|
self,
|
@@ -450,8 +610,7 @@ class TraceClient:
|
|
450
610
|
# span_id_at_eval_call = current_span_var.get()
|
451
611
|
# print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
|
452
612
|
# Prioritize explicitly passed span_id, fallback to context var
|
453
|
-
|
454
|
-
span_id_to_use = span_id if span_id is not None else current_span_ctx_var if current_span_ctx_var is not None else self.tracer.get_current_span()
|
613
|
+
span_id_to_use = span_id if span_id is not None else self.get_current_span()
|
455
614
|
# print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
|
456
615
|
# --- End Modification ---
|
457
616
|
|
@@ -474,6 +633,17 @@ class TraceClient:
|
|
474
633
|
)
|
475
634
|
|
476
635
|
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
636
|
+
|
637
|
+
# Queue evaluation run through background service
|
638
|
+
if self.background_span_service and span_id_to_use:
|
639
|
+
# Get the current span data to avoid race conditions
|
640
|
+
current_span = self.span_id_to_span.get(span_id_to_use)
|
641
|
+
if current_span:
|
642
|
+
self.background_span_service.queue_evaluation_run(
|
643
|
+
eval_run,
|
644
|
+
span_id=span_id_to_use,
|
645
|
+
span_data=current_span
|
646
|
+
)
|
477
647
|
|
478
648
|
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
479
649
|
# --- Modification: Use span_id from eval_run ---
|
@@ -493,58 +663,119 @@ class TraceClient:
|
|
493
663
|
return self
|
494
664
|
|
495
665
|
def record_input(self, inputs: dict):
|
496
|
-
current_span_id =
|
666
|
+
current_span_id = self.get_current_span()
|
497
667
|
if current_span_id:
|
498
668
|
span = self.span_id_to_span[current_span_id]
|
499
669
|
# Ignore self parameter
|
500
670
|
if "self" in inputs:
|
501
671
|
del inputs["self"]
|
502
672
|
span.inputs = inputs
|
673
|
+
|
674
|
+
# Queue span with input data
|
675
|
+
if self.background_span_service:
|
676
|
+
self.background_span_service.queue_span(span, span_state="input")
|
503
677
|
|
504
678
|
def record_agent_name(self, agent_name: str):
|
505
|
-
current_span_id =
|
679
|
+
current_span_id = self.get_current_span()
|
506
680
|
if current_span_id:
|
507
681
|
span = self.span_id_to_span[current_span_id]
|
508
682
|
span.agent_name = agent_name
|
683
|
+
|
684
|
+
# Queue span with agent_name data
|
685
|
+
if self.background_span_service:
|
686
|
+
self.background_span_service.queue_span(span, span_state="agent_name")
|
687
|
+
|
688
|
+
def record_state_before(self, state: dict):
|
689
|
+
"""Records the agent's state before a tool execution on the current span.
|
690
|
+
|
691
|
+
Args:
|
692
|
+
state: A dictionary representing the agent's state.
|
693
|
+
"""
|
694
|
+
current_span_id = self.get_current_span()
|
695
|
+
if current_span_id:
|
696
|
+
span = self.span_id_to_span[current_span_id]
|
697
|
+
span.state_before = state
|
698
|
+
|
699
|
+
# Queue span with state_before data
|
700
|
+
if self.background_span_service:
|
701
|
+
self.background_span_service.queue_span(span, span_state="state_before")
|
702
|
+
|
703
|
+
def record_state_after(self, state: dict):
|
704
|
+
"""Records the agent's state after a tool execution on the current span.
|
705
|
+
|
706
|
+
Args:
|
707
|
+
state: A dictionary representing the agent's state.
|
708
|
+
"""
|
709
|
+
current_span_id = self.get_current_span()
|
710
|
+
if current_span_id:
|
711
|
+
span = self.span_id_to_span[current_span_id]
|
712
|
+
span.state_after = state
|
713
|
+
|
714
|
+
# Queue span with state_after data
|
715
|
+
if self.background_span_service:
|
716
|
+
self.background_span_service.queue_span(span, span_state="state_after")
|
509
717
|
|
510
718
|
async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
|
511
719
|
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
512
720
|
try:
|
513
721
|
result = await coroutine
|
514
722
|
setattr(span, field, result)
|
723
|
+
|
724
|
+
# Queue span with output data now that coroutine is complete
|
725
|
+
if self.background_span_service and field == "output":
|
726
|
+
self.background_span_service.queue_span(span, span_state="output")
|
727
|
+
|
515
728
|
return result
|
516
729
|
except Exception as e:
|
517
730
|
setattr(span, field, f"Error: {str(e)}")
|
731
|
+
|
732
|
+
# Queue span even if there was an error
|
733
|
+
if self.background_span_service and field == "output":
|
734
|
+
self.background_span_service.queue_span(span, span_state="output")
|
735
|
+
|
518
736
|
raise
|
519
737
|
|
520
738
|
def record_output(self, output: Any):
|
521
|
-
current_span_id =
|
739
|
+
current_span_id = self.get_current_span()
|
522
740
|
if current_span_id:
|
523
741
|
span = self.span_id_to_span[current_span_id]
|
524
742
|
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
525
743
|
|
526
744
|
if inspect.iscoroutine(output):
|
527
745
|
asyncio.create_task(self._update_coroutine(span, output, "output"))
|
746
|
+
|
747
|
+
# # Queue span with output data (unless it's pending)
|
748
|
+
if self.background_span_service and not inspect.iscoroutine(output):
|
749
|
+
self.background_span_service.queue_span(span, span_state="output")
|
528
750
|
|
529
751
|
return span # Return the created entry
|
530
752
|
# Removed else block - original didn't have one
|
531
753
|
return None # Return None if no span_id found
|
532
754
|
|
533
755
|
def record_usage(self, usage: TraceUsage):
|
534
|
-
current_span_id =
|
756
|
+
current_span_id = self.get_current_span()
|
535
757
|
if current_span_id:
|
536
758
|
span = self.span_id_to_span[current_span_id]
|
537
759
|
span.usage = usage
|
538
760
|
|
761
|
+
# Queue span with usage data
|
762
|
+
if self.background_span_service:
|
763
|
+
self.background_span_service.queue_span(span, span_state="usage")
|
764
|
+
|
539
765
|
return span # Return the created entry
|
540
766
|
# Removed else block - original didn't have one
|
541
767
|
return None # Return None if no span_id found
|
542
768
|
|
543
|
-
def record_error(self, error: Any):
|
544
|
-
current_span_id =
|
769
|
+
def record_error(self, error: Dict[str, Any]):
|
770
|
+
current_span_id = self.get_current_span()
|
545
771
|
if current_span_id:
|
546
772
|
span = self.span_id_to_span[current_span_id]
|
547
773
|
span.error = error
|
774
|
+
|
775
|
+
# Queue span with error data
|
776
|
+
if self.background_span_service:
|
777
|
+
self.background_span_service.queue_span(span, span_state="error")
|
778
|
+
|
548
779
|
return span
|
549
780
|
return None
|
550
781
|
|
@@ -563,13 +794,19 @@ class TraceClient:
|
|
563
794
|
"""
|
564
795
|
Get the total duration of this trace
|
565
796
|
"""
|
797
|
+
if self.start_time is None:
|
798
|
+
return 0.0 # No duration if trace hasn't been saved yet
|
566
799
|
return time.time() - self.start_time
|
567
800
|
|
568
801
|
def save(self, overwrite: bool = False) -> Tuple[str, dict]:
|
569
802
|
"""
|
570
803
|
Save the current trace to the database.
|
571
|
-
Returns a tuple of (trace_id,
|
804
|
+
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
572
805
|
"""
|
806
|
+
# Set start_time if this is the first save
|
807
|
+
if self.start_time is None:
|
808
|
+
self.start_time = time.time()
|
809
|
+
|
573
810
|
# Calculate total elapsed time
|
574
811
|
total_duration = self.get_duration()
|
575
812
|
# Create trace document - Always use standard keys for top-level counts
|
@@ -577,9 +814,9 @@ class TraceClient:
|
|
577
814
|
"trace_id": self.trace_id,
|
578
815
|
"name": self.name,
|
579
816
|
"project_name": self.project_name,
|
580
|
-
"created_at": datetime.
|
817
|
+
"created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
|
581
818
|
"duration": total_duration,
|
582
|
-
"
|
819
|
+
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
583
820
|
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
584
821
|
"overwrite": overwrite,
|
585
822
|
"offline_mode": self.tracer.offline_mode,
|
@@ -587,19 +824,84 @@ class TraceClient:
|
|
587
824
|
"parent_name": self.parent_name
|
588
825
|
}
|
589
826
|
# --- Log trace data before saving ---
|
590
|
-
self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode)
|
827
|
+
server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
|
591
828
|
|
592
829
|
# upload annotations
|
593
830
|
# TODO: batch to the log endpoint
|
594
831
|
for annotation in self.annotations:
|
595
832
|
self.trace_manager_client.save_annotation(annotation)
|
596
833
|
|
597
|
-
return self.trace_id,
|
834
|
+
return self.trace_id, server_response
|
835
|
+
|
836
|
+
def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
|
837
|
+
"""
|
838
|
+
Save the current trace to the database with rate limiting checks.
|
839
|
+
First checks usage limits, then upserts the trace if allowed.
|
840
|
+
|
841
|
+
Args:
|
842
|
+
overwrite: Whether to overwrite existing traces
|
843
|
+
final_save: Whether this is the final save (updates usage counters)
|
844
|
+
|
845
|
+
Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
|
846
|
+
"""
|
847
|
+
|
848
|
+
|
849
|
+
# Calculate total elapsed time
|
850
|
+
total_duration = self.get_duration()
|
851
|
+
|
852
|
+
# Create trace document
|
853
|
+
trace_data = {
|
854
|
+
"trace_id": self.trace_id,
|
855
|
+
"name": self.name,
|
856
|
+
"project_name": self.project_name,
|
857
|
+
"created_at": datetime.utcfromtimestamp(time.time()).isoformat(),
|
858
|
+
"duration": total_duration,
|
859
|
+
"trace_spans": [span.model_dump() for span in self.trace_spans],
|
860
|
+
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
861
|
+
"overwrite": overwrite,
|
862
|
+
"offline_mode": self.tracer.offline_mode,
|
863
|
+
"parent_trace_id": self.parent_trace_id,
|
864
|
+
"parent_name": self.parent_name
|
865
|
+
}
|
866
|
+
|
867
|
+
# Check usage limits first
|
868
|
+
try:
|
869
|
+
usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
|
870
|
+
# Usage check passed silently - no need to show detailed info
|
871
|
+
except ValueError as e:
|
872
|
+
# Rate limit exceeded
|
873
|
+
warnings.warn(f"Rate limit check failed for live tracing: {e}")
|
874
|
+
raise e
|
875
|
+
|
876
|
+
# If usage check passes, upsert the trace
|
877
|
+
server_response = self.trace_manager_client.upsert_trace(
|
878
|
+
trace_data,
|
879
|
+
offline_mode=self.tracer.offline_mode,
|
880
|
+
show_link=not final_save, # Show link only on initial save, not final save
|
881
|
+
final_save=final_save # Pass final_save to control S3 saving
|
882
|
+
)
|
883
|
+
|
884
|
+
# Update usage counters only on final save
|
885
|
+
if final_save:
|
886
|
+
try:
|
887
|
+
usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
|
888
|
+
# Usage updated silently - no need to show detailed usage info
|
889
|
+
except ValueError as e:
|
890
|
+
# Log warning but don't fail the trace save since the trace was already saved
|
891
|
+
warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
|
892
|
+
|
893
|
+
# Upload annotations
|
894
|
+
# TODO: batch to the log endpoint
|
895
|
+
for annotation in self.annotations:
|
896
|
+
self.trace_manager_client.save_annotation(annotation)
|
897
|
+
if self.start_time is None:
|
898
|
+
self.start_time = time.time()
|
899
|
+
return self.trace_id, server_response
|
598
900
|
|
599
901
|
def delete(self):
|
600
902
|
return self.trace_manager_client.delete_trace(self.trace_id)
|
601
903
|
|
602
|
-
def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info:
|
904
|
+
def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: ExcInfo):
|
603
905
|
if not current_trace:
|
604
906
|
return
|
605
907
|
|
@@ -609,7 +911,360 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
|
|
609
911
|
"message": str(exc_value) if exc_value else "No exception message",
|
610
912
|
"traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
|
611
913
|
}
|
914
|
+
|
915
|
+
# This is where we specially handle exceptions that we might want to collect additional data for.
|
916
|
+
# When we do this, always try checking the module from sys.modules instead of importing. This will
|
917
|
+
# Let us support a wider range of exceptions without needing to import them for all clients.
|
918
|
+
|
919
|
+
# Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
|
920
|
+
# The alternative is to hand select libraries we want from sys.modules and check for them:
|
921
|
+
# As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
|
922
|
+
|
923
|
+
# General HTTP Like errors
|
924
|
+
try:
|
925
|
+
url = getattr(getattr(exc_value, "request", None), "url", None)
|
926
|
+
status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
|
927
|
+
if status_code:
|
928
|
+
formatted_exception["http"] = {
|
929
|
+
"url": url if url else "Unknown URL",
|
930
|
+
"status_code": status_code if status_code else None,
|
931
|
+
}
|
932
|
+
except Exception as e:
|
933
|
+
pass
|
934
|
+
|
612
935
|
current_trace.record_error(formatted_exception)
|
936
|
+
|
937
|
+
# Queue the span with error state through background service
|
938
|
+
if current_trace.background_span_service:
|
939
|
+
current_span_id = current_trace.get_current_span()
|
940
|
+
if current_span_id and current_span_id in current_trace.span_id_to_span:
|
941
|
+
error_span = current_trace.span_id_to_span[current_span_id]
|
942
|
+
current_trace.background_span_service.queue_span(error_span, span_state="error")
|
943
|
+
|
944
|
+
class BackgroundSpanService:
|
945
|
+
"""
|
946
|
+
Background service for queueing and batching trace spans for efficient saving.
|
947
|
+
|
948
|
+
This service:
|
949
|
+
- Queues spans as they complete
|
950
|
+
- Batches them for efficient network usage
|
951
|
+
- Sends spans periodically or when batches reach a certain size
|
952
|
+
- Handles automatic flushing when the main event terminates
|
953
|
+
"""
|
954
|
+
|
955
|
+
def __init__(self, judgment_api_key: str, organization_id: str,
|
956
|
+
batch_size: int = 10, flush_interval: float = 5.0, num_workers: int = 1):
|
957
|
+
"""
|
958
|
+
Initialize the background span service.
|
959
|
+
|
960
|
+
Args:
|
961
|
+
judgment_api_key: API key for Judgment service
|
962
|
+
organization_id: Organization ID
|
963
|
+
batch_size: Number of spans to batch before sending (default: 10)
|
964
|
+
flush_interval: Time in seconds between automatic flushes (default: 5.0)
|
965
|
+
num_workers: Number of worker threads to process the queue (default: 1)
|
966
|
+
"""
|
967
|
+
self.judgment_api_key = judgment_api_key
|
968
|
+
self.organization_id = organization_id
|
969
|
+
self.batch_size = batch_size
|
970
|
+
self.flush_interval = flush_interval
|
971
|
+
self.num_workers = max(1, num_workers) # Ensure at least 1 worker
|
972
|
+
|
973
|
+
# Queue for pending spans
|
974
|
+
self._span_queue = queue.Queue()
|
975
|
+
|
976
|
+
# Background threads for processing spans
|
977
|
+
self._worker_threads = []
|
978
|
+
self._shutdown_event = threading.Event()
|
979
|
+
|
980
|
+
# Track spans that have been sent
|
981
|
+
self._sent_spans = set()
|
982
|
+
|
983
|
+
# Register cleanup on exit
|
984
|
+
atexit.register(self.shutdown)
|
985
|
+
|
986
|
+
# Start the background workers
|
987
|
+
self._start_workers()
|
988
|
+
|
989
|
+
def _start_workers(self):
|
990
|
+
"""Start the background worker threads."""
|
991
|
+
for i in range(self.num_workers):
|
992
|
+
if len(self._worker_threads) < self.num_workers:
|
993
|
+
worker_thread = threading.Thread(
|
994
|
+
target=self._worker_loop,
|
995
|
+
daemon=True,
|
996
|
+
name=f"SpanWorker-{i+1}"
|
997
|
+
)
|
998
|
+
worker_thread.start()
|
999
|
+
self._worker_threads.append(worker_thread)
|
1000
|
+
|
1001
|
+
def _worker_loop(self):
    """
    Main worker loop that processes spans in batches.

    Items are pulled from ``self._span_queue`` into a local batch. The batch
    is handed to ``self._send_batch`` once it holds ``batch_size`` items or
    ``flush_interval`` seconds have passed since the previous flush.
    ``task_done()`` is called once per dequeued item after the send call
    returns (or after an error), which is what lets ``Queue.join()`` in
    ``flush()`` complete. The loop ends when the shutdown event is set and
    the queue is empty; anything still batched is sent before returning.
    """
    idle_poll = 1.0      # seconds to block on an empty batch before re-checking shutdown
    partial_poll = 0.1   # upper bound on a single wait while a partial batch is pending

    batch = []
    last_flush_time = time.time()
    pending_task_count = 0  # Items taken from the queue but not yet marked done

    while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
        try:
            if batch:
                # A partial batch is waiting for its flush deadline. Block for
                # (at most) the remaining time instead of spinning on
                # get_nowait(), which would burn a CPU core until the
                # deadline. The short cap keeps shutdown responsive.
                remaining = self.flush_interval - (time.time() - last_flush_time)
                wait = min(partial_poll, remaining)
            else:
                wait = idle_poll

            if wait > 0:
                try:
                    span_data = self._span_queue.get(timeout=wait)
                    batch.append(span_data)
                    pending_task_count += 1
                except queue.Empty:
                    # Nothing arrived; fall through to the flush check
                    pass

            # Drain whatever else is immediately available, up to batch_size
            while len(batch) < self.batch_size:
                try:
                    span_data = self._span_queue.get_nowait()
                    batch.append(span_data)
                    pending_task_count += 1
                except queue.Empty:
                    break

            current_time = time.time()
            should_flush = (
                len(batch) >= self.batch_size or
                (batch and (current_time - last_flush_time) >= self.flush_interval)
            )

            if should_flush and batch:
                self._send_batch(batch)

                # Mark tasks done only after the send call has returned
                for _ in range(pending_task_count):
                    self._span_queue.task_done()
                pending_task_count = 0

                batch.clear()
                last_flush_time = current_time

        except Exception as e:
            warnings.warn(f"Error in span service worker loop: {e}")
            # Still mark tasks done so a pending join() cannot hang forever;
            # the items in this batch are dropped.
            for _ in range(pending_task_count):
                self._span_queue.task_done()
            pending_task_count = 0
            batch.clear()

    # Final flush on shutdown
    if batch:
        self._send_batch(batch)
        for _ in range(pending_task_count):
            self._span_queue.task_done()
|
1061
|
+
|
1062
|
+
def _send_batch(self, batch: List[Dict[str, Any]]):
    """
    Split a mixed batch by item type and forward each group to its sender.

    Items whose ``'type'`` is ``'span'`` go to ``_send_spans_batch``; items
    whose ``'type'`` is ``'evaluation_run'`` go to
    ``_send_evaluation_runs_batch``. Items of any other type are ignored.
    Any exception is reported with ``warnings.warn`` and not re-raised.

    Args:
        batch: List of queued items, each with ``'type'`` and ``'data'`` keys
    """
    if batch:
        try:
            span_payloads = []
            evaluation_payloads = []

            # Single pass so a malformed item fails before anything is sent.
            for queued in batch:
                kind = queued['type']
                if kind == 'span':
                    span_payloads.append(queued['data'])
                elif kind == 'evaluation_run':
                    evaluation_payloads.append(queued['data'])

            if span_payloads:
                self._send_spans_batch(span_payloads)

            if evaluation_payloads:
                self._send_evaluation_runs_batch(evaluation_payloads)

        except Exception as e:
            warnings.warn(f"Failed to send span batch: {e}")
|
1093
|
+
|
1094
|
+
def _send_spans_batch(self, spans: List[Dict[str, Any]]):
    """
    POST a list of span dicts to the spans batch endpoint as one JSON body.

    A non-200 response and any exception are reported with ``warnings.warn``;
    nothing is raised to the caller and nothing is retried.
    """
    def _stringify(value):
        # json.dumps calls this for values it cannot encode natively:
        # prefer repr(), then str(), then a placeholder description.
        try:
            return repr(value)
        except Exception:
            pass
        try:
            return str(value)
        except Exception as e:
            return f"<Unserializable object of type {type(value).__name__}: {e}>"

    request_body = {
        "spans": spans,
        "organization_id": self.organization_id
    }

    try:
        encoded_body = json.dumps(request_body, default=_stringify)

        result = requests.post(
            JUDGMENT_TRACES_SPANS_BATCH_API_URL,
            data=encoded_body,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.judgment_api_key}",
                "X-Organization-Id": self.organization_id
            },
            verify=True,
            timeout=30  # Bound the request so a worker cannot hang on it
        )

        if result.status_code != HTTPStatus.OK:
            warnings.warn(f"Failed to send spans batch: HTTP {result.status_code} - {result.text}")

    except requests.RequestException as e:
        warnings.warn(f"Network error sending spans batch: {e}")
    except Exception as e:
        warnings.warn(f"Failed to serialize or send spans batch: {e}")
|
1135
|
+
|
1136
|
+
def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
    """
    POST queued evaluation runs, each paired with its span snapshot.

    Every input dict is split into an ``evaluation_run`` part (all keys except
    the span-related ones), an ``associated_span`` part and its ``queued_at``
    timestamp. A non-200 response and any exception are reported with
    ``warnings.warn``; nothing is raised to the caller.
    """
    def _stringify(value):
        # json.dumps calls this for values it cannot encode natively:
        # prefer repr(), then str(), then a placeholder description.
        try:
            return repr(value)
        except Exception:
            pass
        try:
            return str(value)
        except Exception as e:
            return f"<Unserializable object of type {type(value).__name__}: {e}>"

    # Keys that describe the span or the queueing, not the evaluation run itself.
    span_only_keys = ['associated_span_id', 'span_data', 'queued_at']

    evaluation_entries = []
    for run in evaluation_runs:
        run_fields = {}
        for field_name, field_value in run.items():
            if field_name not in span_only_keys:
                run_fields[field_name] = field_value
        evaluation_entries.append({
            "evaluation_run": run_fields,
            "associated_span": {
                "span_id": run.get('associated_span_id'),
                "span_data": run.get('span_data')
            },
            "queued_at": run.get('queued_at')
        })

    request_body = {
        "organization_id": self.organization_id,
        "evaluation_entries": evaluation_entries
    }

    try:
        encoded_body = json.dumps(request_body, default=_stringify)

        result = requests.post(
            JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
            data=encoded_body,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.judgment_api_key}",
                "X-Organization-Id": self.organization_id
            },
            verify=True,
            timeout=30  # Bound the request so a worker cannot hang on it
        )

        if result.status_code != HTTPStatus.OK:
            warnings.warn(f"Failed to send evaluation runs batch: HTTP {result.status_code} - {result.text}")

    except requests.RequestException as e:
        warnings.warn(f"Network error sending evaluation runs batch: {e}")
    except Exception as e:
        warnings.warn(f"Failed to send evaluation runs batch: {e}")
|
1195
|
+
|
1196
|
+
def queue_span(self, span: TraceSpan, span_state: str = "input"):
    """
    Put a snapshot of ``span`` on the queue for the background workers.

    The call is a no-op once shutdown has been signalled.

    Args:
        span: The TraceSpan object to queue; it is serialized with ``model_dump()``
        span_state: State of the span ("input", "output", "completed")
    """
    if self._shutdown_event.is_set():
        return

    # Snapshot now so later changes to the span do not affect the queued copy.
    snapshot = dict(span.model_dump())
    snapshot["span_state"] = span_state
    snapshot["queued_at"] = time.time()

    self._span_queue.put({"type": "span", "data": snapshot})
|
1214
|
+
|
1215
|
+
def queue_evaluation_run(self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan):
    """
    Put an evaluation run and a snapshot of its span on the queue.

    The call is a no-op once shutdown has been signalled.

    Args:
        evaluation_run: The EvaluationRun object to queue
        span_id: The span ID associated with this evaluation run
        span_data: The span at the time of evaluation; it is dumped here so
            the queued copy does not race with later changes to the span
    """
    if self._shutdown_event.is_set():
        return

    run_record = dict(evaluation_run.model_dump())
    run_record["associated_span_id"] = span_id
    run_record["span_data"] = span_data.model_dump()
    run_record["queued_at"] = time.time()

    self._span_queue.put({"type": "evaluation_run", "data": run_record})
|
1235
|
+
|
1236
|
+
def flush(self):
    """
    Block until every queued item has been marked done by a worker.

    This only waits on ``Queue.join()``; it does not itself trigger a send,
    so it returns once the workers have flushed on their own schedule.
    Errors while waiting are reported with ``warnings.warn``.
    """
    try:
        pending = self._span_queue
        pending.join()
    except Exception as exc:
        warnings.warn(f"Error during flush: {exc}")
|
1243
|
+
|
1244
|
+
def shutdown(self):
    """
    Signal shutdown, wait for queued items to be processed, and drop the
    worker handles.

    Registered with ``atexit`` in ``__init__``. Calling it again after the
    shutdown event is set does nothing.
    """
    if not self._shutdown_event.is_set():
        try:
            # From here on queue_span / queue_evaluation_run ignore new items.
            self._shutdown_event.set()

            # Best effort: wait for the workers to drain what is queued.
            try:
                self.flush()
            except Exception as flush_error:
                warnings.warn(f"Error during final flush: {flush_error}")
        except Exception as shutdown_error:
            warnings.warn(f"Error during BackgroundSpanService shutdown: {shutdown_error}")
        finally:
            # Workers are daemon threads, so they are not joined here.
            self._worker_threads.clear()
|
1263
|
+
|
1264
|
+
def get_queue_size(self) -> int:
    """Return the number of items currently waiting in the span queue (``Queue.qsize()``)."""
    pending = self._span_queue
    return pending.qsize()
|
1267
|
+
|
613
1268
|
class _DeepTracer:
|
614
1269
|
_instance: Optional["_DeepTracer"] = None
|
615
1270
|
_lock: threading.Lock = threading.Lock()
|
@@ -619,6 +1274,9 @@ class _DeepTracer:
|
|
619
1274
|
_original_sys_trace: Optional[Callable] = None
|
620
1275
|
_original_threading_trace: Optional[Callable] = None
|
621
1276
|
|
1277
|
+
def __init__(self, tracer: 'Tracer'):
    """Remember the Tracer whose current-trace and current-span accessors the trace hooks call."""
    # NOTE(review): __new__ appears to hand back one shared instance, so each
    # construction would re-point that instance at ``tracer`` — confirm.
    self._tracer = tracer
|
1279
|
+
|
622
1280
|
def _get_qual_name(self, frame) -> str:
|
623
1281
|
func_name = frame.f_code.co_name
|
624
1282
|
module_name = frame.f_globals.get("__name__", "unknown_module")
|
@@ -632,7 +1290,7 @@ class _DeepTracer:
|
|
632
1290
|
except Exception:
|
633
1291
|
return f"{module_name}.{func_name}"
|
634
1292
|
|
635
|
-
def __new__(cls):
|
1293
|
+
def __new__(cls, tracer: 'Tracer' = None):
|
636
1294
|
with cls._lock:
|
637
1295
|
if cls._instance is None:
|
638
1296
|
cls._instance = super().__new__(cls)
|
@@ -718,11 +1376,11 @@ class _DeepTracer:
|
|
718
1376
|
if event not in ("call", "return", "exception"):
|
719
1377
|
return
|
720
1378
|
|
721
|
-
current_trace =
|
1379
|
+
current_trace = self._tracer.get_current_trace()
|
722
1380
|
if not current_trace:
|
723
1381
|
return
|
724
1382
|
|
725
|
-
parent_span_id =
|
1383
|
+
parent_span_id = self._tracer.get_current_span()
|
726
1384
|
if not parent_span_id:
|
727
1385
|
return
|
728
1386
|
|
@@ -784,7 +1442,7 @@ class _DeepTracer:
|
|
784
1442
|
})
|
785
1443
|
self._span_stack.set(span_stack)
|
786
1444
|
|
787
|
-
token =
|
1445
|
+
token = self._tracer.set_current_span(span_id)
|
788
1446
|
frame.f_locals["_judgment_span_token"] = token
|
789
1447
|
|
790
1448
|
span = TraceSpan(
|
@@ -818,7 +1476,7 @@ class _DeepTracer:
|
|
818
1476
|
if not span_stack:
|
819
1477
|
return
|
820
1478
|
|
821
|
-
current_id =
|
1479
|
+
current_id = self._tracer.get_current_span()
|
822
1480
|
|
823
1481
|
span_data = None
|
824
1482
|
for i, entry in enumerate(reversed(span_stack)):
|
@@ -843,12 +1501,12 @@ class _DeepTracer:
|
|
843
1501
|
del current_trace._span_depths[span_data["span_id"]]
|
844
1502
|
|
845
1503
|
if span_stack:
|
846
|
-
|
1504
|
+
self._tracer.set_current_span(span_stack[-1]["span_id"])
|
847
1505
|
else:
|
848
|
-
|
1506
|
+
self._tracer.set_current_span(span_data["parent_span_id"])
|
849
1507
|
|
850
1508
|
if "_judgment_span_token" in frame.f_locals:
|
851
|
-
|
1509
|
+
self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
|
852
1510
|
|
853
1511
|
elif event == "exception":
|
854
1512
|
exc_type = arg[0]
|
@@ -887,18 +1545,28 @@ class _DeepTracer:
|
|
887
1545
|
self._original_threading_trace = None
|
888
1546
|
|
889
1547
|
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
1548
|
+
# NOTE: the commented-out log() helper below appears to be unused; remove it once that is confirmed.
|
1549
|
+
|
1550
|
+
# def log(self, message: str, level: str = "info"):
|
1551
|
+
# """ Log a message with the span context """
|
1552
|
+
# current_trace = self._tracer.get_current_trace()
|
1553
|
+
# if current_trace:
|
1554
|
+
# current_trace.log(message, level)
|
1555
|
+
# else:
|
1556
|
+
# print(f"[{level}] {message}")
|
1557
|
+
# current_trace.record_output({"log": message})
|
898
1558
|
|
899
1559
|
class Tracer:
|
900
1560
|
_instance = None
|
901
1561
|
|
1562
|
+
# Tracer.current_trace class variable is currently used in wrap()
|
1563
|
+
# TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
|
1564
|
+
# Should be fine to do so as long as we keep Tracer as a singleton
|
1565
|
+
current_trace: Optional[TraceClient] = None
|
1566
|
+
# current_span_id: Optional[str] = None
|
1567
|
+
|
1568
|
+
trace_across_async_contexts: bool = False # BY default, we don't trace across async contexts
|
1569
|
+
|
902
1570
|
def __new__(cls, *args, **kwargs):
|
903
1571
|
if cls._instance is None:
|
904
1572
|
cls._instance = super(Tracer, cls).__new__(cls)
|
@@ -907,7 +1575,7 @@ class Tracer:
|
|
907
1575
|
def __init__(
|
908
1576
|
self,
|
909
1577
|
api_key: str = os.getenv("JUDGMENT_API_KEY"),
|
910
|
-
project_name: str =
|
1578
|
+
project_name: str = None,
|
911
1579
|
rules: Optional[List[Rule]] = None, # Added rules parameter
|
912
1580
|
organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
|
913
1581
|
enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
|
@@ -919,7 +1587,13 @@ class Tracer:
|
|
919
1587
|
s3_aws_secret_access_key: Optional[str] = None,
|
920
1588
|
s3_region_name: Optional[str] = None,
|
921
1589
|
offline_mode: bool = False,
|
922
|
-
deep_tracing: bool = True # Deep tracing is enabled by default
|
1590
|
+
deep_tracing: bool = True, # Deep tracing is enabled by default
|
1591
|
+
trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
|
1592
|
+
# Background span service configuration
|
1593
|
+
enable_background_spans: bool = True, # Enable background span service by default
|
1594
|
+
span_batch_size: int = 50, # Number of spans to batch before sending
|
1595
|
+
span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
|
1596
|
+
span_num_workers: int = 10 # Number of worker threads for span processing
|
923
1597
|
):
|
924
1598
|
if not hasattr(self, 'initialized'):
|
925
1599
|
if not api_key:
|
@@ -935,16 +1609,20 @@ class Tracer:
|
|
935
1609
|
raise ValueError("S3 bucket name must be provided when use_s3 is True")
|
936
1610
|
|
937
1611
|
self.api_key: str = api_key
|
938
|
-
self.project_name: str = project_name
|
1612
|
+
self.project_name: str = project_name or str(uuid.uuid4())
|
939
1613
|
self.organization_id: str = organization_id
|
940
|
-
self._current_trace: Optional[str] = None
|
941
|
-
self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
|
942
1614
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
943
1615
|
self.traces: List[Trace] = []
|
944
1616
|
self.initialized: bool = True
|
945
1617
|
self.enable_monitoring: bool = enable_monitoring
|
946
1618
|
self.enable_evaluations: bool = enable_evaluations
|
947
1619
|
self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
|
1620
|
+
self.span_id_to_previous_span_id: Dict[str, str] = {}
|
1621
|
+
self.trace_id_to_previous_trace: Dict[str, TraceClient] = {}
|
1622
|
+
self.current_span_id: Optional[str] = None
|
1623
|
+
self.current_trace: Optional[TraceClient] = None
|
1624
|
+
self.trace_across_async_contexts: bool = trace_across_async_contexts
|
1625
|
+
Tracer.trace_across_async_contexts = trace_across_async_contexts
|
948
1626
|
|
949
1627
|
# Initialize S3 storage if enabled
|
950
1628
|
self.use_s3 = use_s3
|
@@ -958,6 +1636,18 @@ class Tracer:
|
|
958
1636
|
)
|
959
1637
|
self.offline_mode: bool = offline_mode
|
960
1638
|
self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
|
1639
|
+
|
1640
|
+
# Initialize background span service
|
1641
|
+
self.enable_background_spans: bool = enable_background_spans
|
1642
|
+
self.background_span_service: Optional[BackgroundSpanService] = None
|
1643
|
+
if enable_background_spans and not offline_mode:
|
1644
|
+
self.background_span_service = BackgroundSpanService(
|
1645
|
+
judgment_api_key=api_key,
|
1646
|
+
organization_id=organization_id,
|
1647
|
+
batch_size=span_batch_size,
|
1648
|
+
flush_interval=span_flush_interval,
|
1649
|
+
num_workers=span_num_workers
|
1650
|
+
)
|
961
1651
|
|
962
1652
|
elif hasattr(self, 'project_name') and self.project_name != project_name:
|
963
1653
|
warnings.warn(
|
@@ -968,16 +1658,44 @@ class Tracer:
|
|
968
1658
|
)
|
969
1659
|
|
970
1660
|
def set_current_span(self, span_id: str):
    """
    Make ``span_id`` the current span and return the context-variable token.

    The previous ``current_span_id`` is recorded in
    ``span_id_to_previous_span_id`` so ``reset_current_span`` can restore it.
    The id is stored on the instance, on the ``Tracer`` class, and in
    ``current_span_var``.

    Returns:
        The token from ``current_span_var.set()``, or None if setting the
        context variable failed.
    """
    self.span_id_to_previous_span_id[span_id] = getattr(self, 'current_span_id', None)
    self.current_span_id = span_id
    Tracer.current_span_id = span_id
    try:
        token = current_span_var.set(span_id)
    except Exception:
        # Best effort: the instance/class attributes above are already set.
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit are not swallowed.
        token = None
    return token
|
972
1669
|
|
973
1670
|
def get_current_span(self) -> Optional[str]:
    """
    Return the current span id.

    With ``trace_across_async_contexts`` enabled, the instance attribute
    ``current_span_id`` wins and the context variable is the fallback;
    otherwise only the context variable is consulted.
    """
    try:
        current_span_var_val = current_span_var.get()
    except Exception:
        # Treat a failed lookup as "no span". Catch Exception rather than
        # everything so KeyboardInterrupt and SystemExit are not swallowed.
        current_span_var_val = None

    if self.trace_across_async_contexts:
        return self.current_span_id or current_span_var_val
    return current_span_var_val
|
1676
|
+
|
1677
|
+
def reset_current_span(self, token: Optional[str] = None, span_id: Optional[str] = None):
    """
    Restore the span that was current before ``span_id`` was set.

    Args:
        token: The value returned by ``set_current_span`` (passed to
            ``current_span_var.reset``); None skips the context-variable reset.
        span_id: The span being left; defaults to the instance's
            ``current_span_id``.
    """
    if not span_id:
        span_id = self.current_span_id

    if token is not None:
        try:
            current_span_var.reset(token)
        except Exception:
            # Best effort: the instance/class attributes below are still
            # restored. Catch Exception rather than everything so
            # KeyboardInterrupt and SystemExit are not swallowed.
            pass

    self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
    Tracer.current_span_id = self.current_span_id
|
975
1686
|
|
976
1687
|
def set_current_trace(self, trace: TraceClient):
    """
    Make ``trace`` the current trace and return the context-variable token.

    The previous ``current_trace`` is recorded in ``trace_id_to_previous_trace``
    under ``trace.trace_id`` so ``reset_current_trace`` can restore it. The
    trace is stored on the instance, on the ``Tracer`` class, and in
    ``current_trace_var``.

    Returns:
        The token from ``current_trace_var.set()``, or None if setting the
        context variable failed.
    """
    self.trace_id_to_previous_trace[trace.trace_id] = getattr(self, 'current_trace', None)
    self.current_trace = trace
    Tracer.current_trace = trace
    try:
        token = current_trace_var.set(trace)
    except Exception:
        # Best effort: the instance/class attributes above are already set.
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit are not swallowed.
        token = None
    return token
|
981
1699
|
|
982
1700
|
def get_current_trace(self) -> Optional[TraceClient]:
|
983
1701
|
"""
|
@@ -987,23 +1705,34 @@ class Tracer:
|
|
987
1705
|
If not found (e.g., context lost across threads/tasks),
|
988
1706
|
it falls back to the active trace client managed by the callback handler.
|
989
1707
|
"""
|
990
|
-
|
991
|
-
|
992
|
-
|
1708
|
+
try:
|
1709
|
+
current_trace_var_val = current_trace_var.get()
|
1710
|
+
except:
|
1711
|
+
current_trace_var_val = None
|
1712
|
+
|
1713
|
+
# Use context variable or class variable based on trace_across_async_contexts setting
|
1714
|
+
context_trace = (self.current_trace or current_trace_var_val) if self.trace_across_async_contexts else current_trace_var_val
|
1715
|
+
|
1716
|
+
# If we found a trace from context, return it
|
1717
|
+
if context_trace:
|
1718
|
+
return context_trace
|
993
1719
|
|
994
|
-
# Fallback: Check the active client potentially set by a callback handler
|
1720
|
+
# Fallback: Check the active client potentially set by a callback handler (e.g., LangGraph)
|
995
1721
|
if hasattr(self, '_active_trace_client') and self._active_trace_client:
|
996
|
-
# warnings.warn("Falling back to _active_trace_client in get_current_trace. ContextVar might be lost.", RuntimeWarning)
|
997
1722
|
return self._active_trace_client
|
998
1723
|
|
999
|
-
# If neither is available
|
1000
|
-
# warnings.warn("No current trace found in context variable or active client fallback.", RuntimeWarning)
|
1724
|
+
# If neither is available, return None
|
1001
1725
|
return None
|
1002
|
-
|
1003
|
-
def
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1726
|
+
|
1727
|
+
def reset_current_trace(self, token: Optional[str] = None, trace_id: Optional[str] = None):
    """
    Restore the trace that was current before ``trace_id`` was set.

    Args:
        token: The value returned by ``set_current_trace`` (passed to
            ``current_trace_var.reset``); None skips the context-variable reset.
        trace_id: The trace being left; defaults to the id of the instance's
            ``current_trace`` when there is one.
    """
    if not trace_id and self.current_trace:
        trace_id = self.current_trace.trace_id

    if token is not None:
        try:
            current_trace_var.reset(token)
        except Exception:
            # Best effort: the instance/class attributes below are still
            # restored. Catch Exception rather than everything so
            # KeyboardInterrupt and SystemExit are not swallowed.
            pass

    self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
    Tracer.current_trace = self.current_trace
|
1007
1736
|
|
1008
1737
|
@contextmanager
|
1009
1738
|
def trace(
|
@@ -1018,7 +1747,7 @@ class Tracer:
|
|
1018
1747
|
project = project_name if project_name is not None else self.project_name
|
1019
1748
|
|
1020
1749
|
# Get parent trace info from context
|
1021
|
-
parent_trace =
|
1750
|
+
parent_trace = self.get_current_trace()
|
1022
1751
|
parent_trace_id = None
|
1023
1752
|
parent_name = None
|
1024
1753
|
|
@@ -1040,7 +1769,7 @@ class Tracer:
|
|
1040
1769
|
)
|
1041
1770
|
|
1042
1771
|
# Set the current trace in context variables
|
1043
|
-
token =
|
1772
|
+
token = self.set_current_trace(trace)
|
1044
1773
|
|
1045
1774
|
# Automatically create top-level span
|
1046
1775
|
with trace.span(name or "unnamed_trace") as span:
|
@@ -1049,13 +1778,13 @@ class Tracer:
|
|
1049
1778
|
yield trace
|
1050
1779
|
finally:
|
1051
1780
|
# Reset the context variable
|
1052
|
-
|
1781
|
+
self.reset_current_trace(token)
|
1053
1782
|
|
1054
1783
|
|
1055
1784
|
def log(self, msg: str, label: str = "log", score: int = 1):
|
1056
1785
|
"""Log a message with the current span context"""
|
1057
|
-
current_span_id =
|
1058
|
-
current_trace =
|
1786
|
+
current_span_id = self.get_current_span()
|
1787
|
+
current_trace = self.get_current_trace()
|
1059
1788
|
if current_span_id:
|
1060
1789
|
annotation = TraceAnnotation(
|
1061
1790
|
span_id=current_span_id,
|
@@ -1068,32 +1797,92 @@ class Tracer:
|
|
1068
1797
|
|
1069
1798
|
rprint(f"[bold]{label}:[/bold] {msg}")
|
1070
1799
|
|
1071
|
-
def identify(self, identifier: str):
|
1800
|
+
def identify(self, identifier: str, track_state: bool = False, track_attributes: Optional[List[str]] = None, field_mappings: Optional[Dict[str, str]] = None):
    """
    Class decorator that registers a class under a custom identifier.

    The decorated class is returned unchanged. Its ``__name__`` becomes a key
    in ``self.class_identifiers`` whose value is a settings dict with the keys
    ``identifier``, ``track_state``, ``track_attributes`` and
    ``field_mappings``.

    Args:
        identifier: The identifier to associate with the decorated class.
        track_state: Whether instance state should be captured for traced
            calls on instances of the class. Defaults to False.
        track_attributes: Optional list of attribute names to capture. When
            None, state capture uses the instance's attributes that do not
            start with '_'.
        field_mappings: Optional dictionary stored with the settings (an empty
            dict when omitted) and attached to captured state under the
            'field_mappings' key.

    Example:
        @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
        class User:
            # Class implementation
    """
    def decorator(cls):
        # A fresh settings dict is built for every decorated class.
        settings = dict(
            identifier=identifier,
            track_state=track_state,
            track_attributes=track_attributes,
            field_mappings=field_mappings or {},
        )
        self.class_identifiers[cls.__name__] = settings
        return cls

    return decorator
|
1096
1838
|
|
1839
|
+
def _capture_instance_state(self, instance: Any, class_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Capture the state of an instance based on class configuration.

    Args:
        instance: The instance to capture the state of.
        class_config: Configuration dictionary for state capture; the
            'track_attributes' and 'field_mappings' keys are read.

    Returns:
        A dict of attribute name to value. With 'track_attributes' set, only
        those attributes are read (missing ones map to None); otherwise every
        entry of the instance ``__dict__`` whose name does not start with '_'
        is included. A truthy 'field_mappings' is attached unchanged under
        the 'field_mappings' key.
    """
    track_attributes = class_config.get('track_attributes')
    field_mappings = class_config.get('field_mappings')

    if track_attributes:
        state = {attr: getattr(instance, attr, None) for attr in track_attributes}
    else:
        # Instances of __slots__-only classes have no __dict__; fall back to
        # an empty state instead of raising inside the traced call.
        attributes = getattr(instance, '__dict__', {})
        state = {k: v for k, v in attributes.items() if not k.startswith('_')}

    if field_mappings:
        state['field_mappings'] = field_mappings

    return state
|
1861
|
+
|
1862
|
+
|
1863
|
+
def _get_instance_state_if_tracked(self, args):
    """
    Capture state for the first positional argument when its class is tracked.

    The first element of ``args`` is treated as the instance. State is
    captured only when the instance's class name is registered in
    ``self.class_identifiers`` with a dict whose 'track_state' is truthy.

    Returns the captured state dict if tracking is enabled, None otherwise.
    """
    if not args or not hasattr(args[0], '__class__'):
        return None

    target = args[0]
    settings = self.class_identifiers.get(target.__class__.__name__)

    # Non-dict registrations carry no tracking settings.
    if not isinstance(settings, dict):
        return None
    if not settings.get('track_state', False):
        return None

    return self._capture_instance_state(target, settings)
|
1876
|
+
|
1877
|
+
def _conditionally_capture_and_record_state(self, trace_client_instance: TraceClient, args: tuple, is_before: bool):
    """
    Capture instance state if tracked and record it on the trace client.

    Nothing is recorded when no state is captured or the captured state is
    empty. ``is_before`` selects ``record_state_before`` over
    ``record_state_after``.
    """
    captured = self._get_instance_state_if_tracked(args)
    if not captured:
        return

    # Only the recorder that is actually needed is looked up.
    record = (
        trace_client_instance.record_state_before
        if is_before
        else trace_client_instance.record_state_after
    )
    record(captured)
|
1885
|
+
|
1097
1886
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
|
1098
1887
|
"""
|
1099
1888
|
Decorator to trace function execution with detailed entry/exit information.
|
@@ -1139,7 +1928,7 @@ class Tracer:
|
|
1139
1928
|
agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
|
1140
1929
|
|
1141
1930
|
# Get current trace from context
|
1142
|
-
current_trace =
|
1931
|
+
current_trace = self.get_current_trace()
|
1143
1932
|
|
1144
1933
|
# If there's no current trace, create a root trace
|
1145
1934
|
if not current_trace:
|
@@ -1160,7 +1949,7 @@ class Tracer:
|
|
1160
1949
|
|
1161
1950
|
# Save empty trace and set trace context
|
1162
1951
|
# current_trace.save(empty_save=True, overwrite=overwrite)
|
1163
|
-
trace_token =
|
1952
|
+
trace_token = self.set_current_trace(current_trace)
|
1164
1953
|
|
1165
1954
|
try:
|
1166
1955
|
# Use span for the function execution within the root trace
|
@@ -1171,9 +1960,12 @@ class Tracer:
|
|
1171
1960
|
span.record_input(inputs)
|
1172
1961
|
if agent_name:
|
1173
1962
|
span.record_agent_name(agent_name)
|
1963
|
+
|
1964
|
+
# Capture state before execution
|
1965
|
+
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
1174
1966
|
|
1175
1967
|
if use_deep_tracing:
|
1176
|
-
with _DeepTracer():
|
1968
|
+
with _DeepTracer(self):
|
1177
1969
|
result = await func(*args, **kwargs)
|
1178
1970
|
else:
|
1179
1971
|
try:
|
@@ -1181,17 +1973,39 @@ class Tracer:
|
|
1181
1973
|
except Exception as e:
|
1182
1974
|
_capture_exception_for_trace(current_trace, sys.exc_info())
|
1183
1975
|
raise e
|
1184
|
-
|
1976
|
+
|
1977
|
+
# Capture state after execution
|
1978
|
+
self._conditionally_capture_and_record_state(span, args, is_before=False)
|
1979
|
+
|
1185
1980
|
# Record output
|
1186
1981
|
span.record_output(result)
|
1187
1982
|
return result
|
1188
1983
|
finally:
|
1984
|
+
# Flush background spans before saving the trace
|
1985
|
+
|
1986
|
+
complete_trace_data = {
|
1987
|
+
"trace_id": current_trace.trace_id,
|
1988
|
+
"name": current_trace.name,
|
1989
|
+
"created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
|
1990
|
+
"duration": current_trace.get_duration(),
|
1991
|
+
"trace_spans": [span.model_dump() for span in current_trace.trace_spans],
|
1992
|
+
"overwrite": overwrite,
|
1993
|
+
"offline_mode": self.offline_mode,
|
1994
|
+
"parent_trace_id": current_trace.parent_trace_id,
|
1995
|
+
"parent_name": current_trace.parent_name
|
1996
|
+
}
|
1189
1997
|
# Save the completed trace
|
1190
|
-
trace_id,
|
1191
|
-
|
1998
|
+
trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
|
1999
|
+
|
2000
|
+
# Store the complete trace data instead of just server response
|
2001
|
+
|
2002
|
+
self.traces.append(complete_trace_data)
|
2003
|
+
|
2004
|
+
# if self.background_span_service:
|
2005
|
+
# self.background_span_service.flush()
|
1192
2006
|
|
1193
2007
|
# Reset trace context (span context resets automatically)
|
1194
|
-
|
2008
|
+
self.reset_current_trace(trace_token)
|
1195
2009
|
else:
|
1196
2010
|
with current_trace.span(span_name, span_type=span_type) as span:
|
1197
2011
|
inputs = combine_args_kwargs(func, args, kwargs)
|
@@ -1199,8 +2013,11 @@ class Tracer:
|
|
1199
2013
|
if agent_name:
|
1200
2014
|
span.record_agent_name(agent_name)
|
1201
2015
|
|
2016
|
+
# Capture state before execution
|
2017
|
+
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
2018
|
+
|
1202
2019
|
if use_deep_tracing:
|
1203
|
-
with _DeepTracer():
|
2020
|
+
with _DeepTracer(self):
|
1204
2021
|
result = await func(*args, **kwargs)
|
1205
2022
|
else:
|
1206
2023
|
try:
|
@@ -1208,6 +2025,9 @@ class Tracer:
|
|
1208
2025
|
except Exception as e:
|
1209
2026
|
_capture_exception_for_trace(current_trace, sys.exc_info())
|
1210
2027
|
raise e
|
2028
|
+
|
2029
|
+
# Capture state after execution
|
2030
|
+
self._conditionally_capture_and_record_state(span, args, is_before=False)
|
1211
2031
|
|
1212
2032
|
span.record_output(result)
|
1213
2033
|
return result
|
@@ -1226,7 +2046,7 @@ class Tracer:
|
|
1226
2046
|
class_name = args[0].__class__.__name__
|
1227
2047
|
agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
|
1228
2048
|
# Get current trace from context
|
1229
|
-
current_trace =
|
2049
|
+
current_trace = self.get_current_trace()
|
1230
2050
|
|
1231
2051
|
# If there's no current trace, create a root trace
|
1232
2052
|
if not current_trace:
|
@@ -1247,7 +2067,7 @@ class Tracer:
|
|
1247
2067
|
|
1248
2068
|
# Save empty trace and set trace context
|
1249
2069
|
# current_trace.save(empty_save=True, overwrite=overwrite)
|
1250
|
-
trace_token =
|
2070
|
+
trace_token = self.set_current_trace(current_trace)
|
1251
2071
|
|
1252
2072
|
try:
|
1253
2073
|
# Use span for the function execution within the root trace
|
@@ -1258,8 +2078,11 @@ class Tracer:
|
|
1258
2078
|
span.record_input(inputs)
|
1259
2079
|
if agent_name:
|
1260
2080
|
span.record_agent_name(agent_name)
|
2081
|
+
# Capture state before execution
|
2082
|
+
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
2083
|
+
|
1261
2084
|
if use_deep_tracing:
|
1262
|
-
with _DeepTracer():
|
2085
|
+
with _DeepTracer(self):
|
1263
2086
|
result = func(*args, **kwargs)
|
1264
2087
|
else:
|
1265
2088
|
try:
|
@@ -1267,17 +2090,36 @@ class Tracer:
|
|
1267
2090
|
except Exception as e:
|
1268
2091
|
_capture_exception_for_trace(current_trace, sys.exc_info())
|
1269
2092
|
raise e
|
2093
|
+
|
2094
|
+
# Capture state after execution
|
2095
|
+
self._conditionally_capture_and_record_state(span, args, is_before=False)
|
2096
|
+
|
1270
2097
|
|
1271
2098
|
# Record output
|
1272
2099
|
span.record_output(result)
|
1273
2100
|
return result
|
1274
2101
|
finally:
|
1275
|
-
#
|
1276
|
-
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1277
|
-
self.traces.append(trace)
|
2102
|
+
# Flush background spans before saving the trace
|
1278
2103
|
|
2104
|
+
|
2105
|
+
# Save the completed trace
|
2106
|
+
trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
|
2107
|
+
|
2108
|
+
# Store the complete trace data instead of just server response
|
2109
|
+
complete_trace_data = {
|
2110
|
+
"trace_id": current_trace.trace_id,
|
2111
|
+
"name": current_trace.name,
|
2112
|
+
"created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
|
2113
|
+
"duration": current_trace.get_duration(),
|
2114
|
+
"trace_spans": [span.model_dump() for span in current_trace.trace_spans],
|
2115
|
+
"overwrite": overwrite,
|
2116
|
+
"offline_mode": self.offline_mode,
|
2117
|
+
"parent_trace_id": current_trace.parent_trace_id,
|
2118
|
+
"parent_name": current_trace.parent_name
|
2119
|
+
}
|
2120
|
+
self.traces.append(complete_trace_data)
|
1279
2121
|
# Reset trace context (span context resets automatically)
|
1280
|
-
|
2122
|
+
self.reset_current_trace(trace_token)
|
1281
2123
|
else:
|
1282
2124
|
with current_trace.span(span_name, span_type=span_type) as span:
|
1283
2125
|
|
@@ -1286,8 +2128,11 @@ class Tracer:
|
|
1286
2128
|
if agent_name:
|
1287
2129
|
span.record_agent_name(agent_name)
|
1288
2130
|
|
2131
|
+
# Capture state before execution
|
2132
|
+
self._conditionally_capture_and_record_state(span, args, is_before=True)
|
2133
|
+
|
1289
2134
|
if use_deep_tracing:
|
1290
|
-
with _DeepTracer():
|
2135
|
+
with _DeepTracer(self):
|
1291
2136
|
result = func(*args, **kwargs)
|
1292
2137
|
else:
|
1293
2138
|
try:
|
@@ -1296,11 +2141,63 @@ class Tracer:
|
|
1296
2141
|
_capture_exception_for_trace(current_trace, sys.exc_info())
|
1297
2142
|
raise e
|
1298
2143
|
|
2144
|
+
# Capture state after execution
|
2145
|
+
self._conditionally_capture_and_record_state(span, args, is_before=False)
|
2146
|
+
|
1299
2147
|
span.record_output(result)
|
1300
2148
|
return result
|
1301
2149
|
|
1302
2150
|
return wrapper
|
1303
2151
|
|
2152
|
+
def observe_tools(self, cls=None, *, exclude_methods: Optional[List[str]] = None,
|
2153
|
+
include_private: bool = False, warn_on_double_decoration: bool = True):
|
2154
|
+
"""
|
2155
|
+
Automatically adds @observe(span_type="tool") to all methods in a class.
|
2156
|
+
|
2157
|
+
Args:
|
2158
|
+
cls: The class to decorate (automatically provided when used as decorator)
|
2159
|
+
exclude_methods: List of method names to skip decorating. Defaults to common magic methods
|
2160
|
+
include_private: Whether to decorate methods starting with underscore. Defaults to False
|
2161
|
+
warn_on_double_decoration: Whether to print warnings when skipping already-decorated methods. Defaults to True
|
2162
|
+
"""
|
2163
|
+
|
2164
|
+
if exclude_methods is None:
|
2165
|
+
exclude_methods = ['__init__', '__new__', '__del__', '__str__', '__repr__']
|
2166
|
+
|
2167
|
+
def decorate_class(cls):
|
2168
|
+
if not self.enable_monitoring:
|
2169
|
+
return cls
|
2170
|
+
|
2171
|
+
decorated = []
|
2172
|
+
skipped = []
|
2173
|
+
|
2174
|
+
for name in dir(cls):
|
2175
|
+
method = getattr(cls, name)
|
2176
|
+
|
2177
|
+
if (not callable(method) or
|
2178
|
+
name in exclude_methods or
|
2179
|
+
(name.startswith('_') and not include_private) or
|
2180
|
+
not hasattr(cls, name)):
|
2181
|
+
continue
|
2182
|
+
|
2183
|
+
if hasattr(method, '_judgment_span_name'):
|
2184
|
+
skipped.append(name)
|
2185
|
+
if warn_on_double_decoration:
|
2186
|
+
print(f"Warning: {cls.__name__}.{name} already decorated, skipping")
|
2187
|
+
continue
|
2188
|
+
|
2189
|
+
try:
|
2190
|
+
decorated_method = self.observe(method, span_type="tool")
|
2191
|
+
setattr(cls, name, decorated_method)
|
2192
|
+
decorated.append(name)
|
2193
|
+
except Exception as e:
|
2194
|
+
if warn_on_double_decoration:
|
2195
|
+
print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")
|
2196
|
+
|
2197
|
+
return cls
|
2198
|
+
|
2199
|
+
return decorate_class if cls is None else decorate_class(cls)
|
2200
|
+
|
1304
2201
|
def async_evaluate(self, *args, **kwargs):
|
1305
2202
|
if not self.enable_evaluations:
|
1306
2203
|
return
|
@@ -1308,13 +2205,7 @@ class Tracer:
|
|
1308
2205
|
# --- Get trace_id passed explicitly (if any) ---
|
1309
2206
|
passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
|
1310
2207
|
|
1311
|
-
|
1312
|
-
current_trace = current_trace_var.get()
|
1313
|
-
|
1314
|
-
# --- Fallback Logic: Use active client only if context var is empty ---
|
1315
|
-
if not current_trace:
|
1316
|
-
current_trace = self._active_trace_client # Use the fallback
|
1317
|
-
# --- End Fallback Logic ---
|
2208
|
+
current_trace = self.get_current_trace()
|
1318
2209
|
|
1319
2210
|
if current_trace:
|
1320
2211
|
# Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
|
@@ -1325,13 +2216,34 @@ class Tracer:
|
|
1325
2216
|
else:
|
1326
2217
|
warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
|
1327
2218
|
|
1328
|
-
def
|
2219
|
+
def get_background_span_service(self) -> Optional[BackgroundSpanService]:
|
2220
|
+
"""Get the background span service instance."""
|
2221
|
+
return self.background_span_service
|
2222
|
+
|
2223
|
+
def flush_background_spans(self):
|
2224
|
+
"""Flush all pending spans in the background service."""
|
2225
|
+
if self.background_span_service:
|
2226
|
+
self.background_span_service.flush()
|
2227
|
+
|
2228
|
+
def shutdown_background_service(self):
|
2229
|
+
"""Shutdown the background span service."""
|
2230
|
+
if self.background_span_service:
|
2231
|
+
self.background_span_service.shutdown()
|
2232
|
+
self.background_span_service = None
|
2233
|
+
|
2234
|
+
def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts) -> Any:
|
1329
2235
|
"""
|
1330
2236
|
Wraps an API client to add tracing capabilities.
|
1331
2237
|
Supports OpenAI, Together, Anthropic, and Google GenAI clients.
|
1332
2238
|
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
|
1333
2239
|
"""
|
1334
2240
|
span_name, original_create, original_responses_create, original_stream = _get_client_config(client)
|
2241
|
+
|
2242
|
+
def _get_current_trace():
|
2243
|
+
if trace_across_async_contexts:
|
2244
|
+
return Tracer.current_trace
|
2245
|
+
else:
|
2246
|
+
return current_trace_var.get()
|
1335
2247
|
|
1336
2248
|
def _record_input_and_check_streaming(span, kwargs, is_responses=False):
|
1337
2249
|
"""Record input and check for streaming"""
|
@@ -1367,18 +2279,21 @@ def wrap(client: Any) -> Any:
|
|
1367
2279
|
output, usage = format_func(client, response)
|
1368
2280
|
span.record_output(output)
|
1369
2281
|
span.record_usage(usage)
|
2282
|
+
|
2283
|
+
# Queue the completed LLM span now that it has all data (input, output, usage)
|
2284
|
+
current_trace = _get_current_trace()
|
2285
|
+
if current_trace and current_trace.background_span_service:
|
2286
|
+
# Get the current span from the trace client
|
2287
|
+
current_span_id = current_trace.get_current_span()
|
2288
|
+
if current_span_id and current_span_id in current_trace.span_id_to_span:
|
2289
|
+
completed_span = current_trace.span_id_to_span[current_span_id]
|
2290
|
+
current_trace.background_span_service.queue_span(completed_span, span_state="completed")
|
2291
|
+
|
1370
2292
|
return response
|
1371
2293
|
|
1372
|
-
def _handle_error(span, e, is_async):
|
1373
|
-
"""Handle and record errors"""
|
1374
|
-
call_type = "async" if is_async else "sync"
|
1375
|
-
print(f"Error during wrapped {call_type} API call ({span_name}): {e}")
|
1376
|
-
span.record_output({"error": str(e)})
|
1377
|
-
raise
|
1378
|
-
|
1379
2294
|
# --- Traced Async Functions ---
|
1380
2295
|
async def traced_create_async(*args, **kwargs):
|
1381
|
-
current_trace =
|
2296
|
+
current_trace = _get_current_trace()
|
1382
2297
|
if not current_trace:
|
1383
2298
|
return await original_create(*args, **kwargs)
|
1384
2299
|
|
@@ -1389,11 +2304,12 @@ def wrap(client: Any) -> Any:
|
|
1389
2304
|
response_or_iterator = await original_create(*args, **kwargs)
|
1390
2305
|
return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
|
1391
2306
|
except Exception as e:
|
1392
|
-
|
2307
|
+
_capture_exception_for_trace(span, sys.exc_info())
|
2308
|
+
raise e
|
1393
2309
|
|
1394
2310
|
# Async responses for OpenAI clients
|
1395
2311
|
async def traced_response_create_async(*args, **kwargs):
|
1396
|
-
current_trace =
|
2312
|
+
current_trace = _get_current_trace()
|
1397
2313
|
if not current_trace:
|
1398
2314
|
return await original_responses_create(*args, **kwargs)
|
1399
2315
|
|
@@ -1404,11 +2320,12 @@ def wrap(client: Any) -> Any:
|
|
1404
2320
|
response_or_iterator = await original_responses_create(*args, **kwargs)
|
1405
2321
|
return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
|
1406
2322
|
except Exception as e:
|
1407
|
-
|
2323
|
+
_capture_exception_for_trace(span, sys.exc_info())
|
2324
|
+
raise e
|
1408
2325
|
|
1409
2326
|
# Function replacing .stream() for async clients
|
1410
2327
|
def traced_stream_async(*args, **kwargs):
|
1411
|
-
current_trace =
|
2328
|
+
current_trace = _get_current_trace()
|
1412
2329
|
if not current_trace or not original_stream:
|
1413
2330
|
return original_stream(*args, **kwargs)
|
1414
2331
|
|
@@ -1424,7 +2341,7 @@ def wrap(client: Any) -> Any:
|
|
1424
2341
|
|
1425
2342
|
# --- Traced Sync Functions ---
|
1426
2343
|
def traced_create_sync(*args, **kwargs):
|
1427
|
-
current_trace =
|
2344
|
+
current_trace = _get_current_trace()
|
1428
2345
|
if not current_trace:
|
1429
2346
|
return original_create(*args, **kwargs)
|
1430
2347
|
|
@@ -1435,10 +2352,11 @@ def wrap(client: Any) -> Any:
|
|
1435
2352
|
response_or_iterator = original_create(*args, **kwargs)
|
1436
2353
|
return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
|
1437
2354
|
except Exception as e:
|
1438
|
-
|
2355
|
+
_capture_exception_for_trace(span, sys.exc_info())
|
2356
|
+
raise e
|
1439
2357
|
|
1440
2358
|
def traced_response_create_sync(*args, **kwargs):
|
1441
|
-
current_trace =
|
2359
|
+
current_trace = _get_current_trace()
|
1442
2360
|
if not current_trace:
|
1443
2361
|
return original_responses_create(*args, **kwargs)
|
1444
2362
|
|
@@ -1449,11 +2367,12 @@ def wrap(client: Any) -> Any:
|
|
1449
2367
|
response_or_iterator = original_responses_create(*args, **kwargs)
|
1450
2368
|
return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
|
1451
2369
|
except Exception as e:
|
1452
|
-
|
2370
|
+
_capture_exception_for_trace(span, sys.exc_info())
|
2371
|
+
raise e
|
1453
2372
|
|
1454
2373
|
# Function replacing sync .stream()
|
1455
2374
|
def traced_stream_sync(*args, **kwargs):
|
1456
|
-
current_trace =
|
2375
|
+
current_trace = _get_current_trace()
|
1457
2376
|
if not current_trace or not original_stream:
|
1458
2377
|
return original_stream(*args, **kwargs)
|
1459
2378
|
|
@@ -1472,6 +2391,8 @@ def wrap(client: Any) -> Any:
|
|
1472
2391
|
client.chat.completions.create = traced_create_async
|
1473
2392
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
1474
2393
|
client.responses.create = traced_response_create_async
|
2394
|
+
if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
|
2395
|
+
client.beta.chat.completions.parse = traced_create_async
|
1475
2396
|
elif isinstance(client, AsyncAnthropic):
|
1476
2397
|
client.messages.create = traced_create_async
|
1477
2398
|
if original_stream:
|
@@ -1482,6 +2403,8 @@ def wrap(client: Any) -> Any:
|
|
1482
2403
|
client.chat.completions.create = traced_create_sync
|
1483
2404
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
1484
2405
|
client.responses.create = traced_response_create_sync
|
2406
|
+
if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
|
2407
|
+
client.beta.chat.completions.parse = traced_create_sync
|
1485
2408
|
elif isinstance(client, Anthropic):
|
1486
2409
|
client.messages.create = traced_create_sync
|
1487
2410
|
if original_stream:
|
@@ -1808,6 +2731,15 @@ def _sync_stream_wrapper(
|
|
1808
2731
|
# Update the trace entry with the accumulated content and usage
|
1809
2732
|
span.output = "".join(content_parts)
|
1810
2733
|
span.usage = final_usage
|
2734
|
+
|
2735
|
+
# Queue the completed LLM span now that streaming is done and all data is available
|
2736
|
+
# Note: We need to get the TraceClient that owns this span to access the background service
|
2737
|
+
# We can find this through the tracer singleton since spans are associated with traces
|
2738
|
+
from judgeval.common.tracer import Tracer
|
2739
|
+
tracer_instance = Tracer._instance
|
2740
|
+
if tracer_instance and tracer_instance.background_span_service:
|
2741
|
+
tracer_instance.background_span_service.queue_span(span, span_state="completed")
|
2742
|
+
|
1811
2743
|
# Note: We might need to adjust _serialize_output if this dict causes issues,
|
1812
2744
|
# but Pydantic's model_dump should handle dicts.
|
1813
2745
|
|
@@ -1892,6 +2824,12 @@ async def _async_stream_wrapper(
|
|
1892
2824
|
span.usage = usage_info
|
1893
2825
|
start_ts = getattr(span, 'created_at', time.time())
|
1894
2826
|
span.duration = time.time() - start_ts
|
2827
|
+
|
2828
|
+
# Queue the completed LLM span now that async streaming is done and all data is available
|
2829
|
+
from judgeval.common.tracer import Tracer
|
2830
|
+
tracer_instance = Tracer._instance
|
2831
|
+
if tracer_instance and tracer_instance.background_span_service:
|
2832
|
+
tracer_instance.background_span_service.queue_span(span, span_state="completed")
|
1895
2833
|
# else: # Handle error case if necessary, but remove debug print
|
1896
2834
|
|
1897
2835
|
def cost_per_token(*args, **kwargs):
|
@@ -1940,12 +2878,12 @@ class _BaseStreamManagerWrapper:
|
|
1940
2878
|
|
1941
2879
|
class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
|
1942
2880
|
async def __aenter__(self):
|
1943
|
-
self._parent_span_id_at_entry =
|
2881
|
+
self._parent_span_id_at_entry = self._trace_client.get_current_span()
|
1944
2882
|
if not self._trace_client:
|
1945
2883
|
return await self._original_manager.__aenter__()
|
1946
2884
|
|
1947
2885
|
span_id, span = self._create_span()
|
1948
|
-
self._span_context_token =
|
2886
|
+
self._span_context_token = self._trace_client.set_current_span(span_id)
|
1949
2887
|
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
1950
2888
|
|
1951
2889
|
# Call the original __aenter__ and expect it to be an async generator
|
@@ -1955,20 +2893,20 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
|
|
1955
2893
|
|
1956
2894
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
1957
2895
|
if hasattr(self, '_span_context_token'):
|
1958
|
-
span_id =
|
2896
|
+
span_id = self._trace_client.get_current_span()
|
1959
2897
|
self._finalize_span(span_id)
|
1960
|
-
|
2898
|
+
self._trace_client.reset_current_span(self._span_context_token)
|
1961
2899
|
delattr(self, '_span_context_token')
|
1962
2900
|
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
1963
2901
|
|
1964
2902
|
class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
|
1965
2903
|
def __enter__(self):
|
1966
|
-
self._parent_span_id_at_entry =
|
2904
|
+
self._parent_span_id_at_entry = self._trace_client.get_current_span()
|
1967
2905
|
if not self._trace_client:
|
1968
2906
|
return self._original_manager.__enter__()
|
1969
2907
|
|
1970
2908
|
span_id, span = self._create_span()
|
1971
|
-
self._span_context_token =
|
2909
|
+
self._span_context_token = self._trace_client.set_current_span(span_id)
|
1972
2910
|
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
1973
2911
|
|
1974
2912
|
raw_iterator = self._original_manager.__enter__()
|
@@ -1977,9 +2915,9 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
|
|
1977
2915
|
|
1978
2916
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
1979
2917
|
if hasattr(self, '_span_context_token'):
|
1980
|
-
span_id =
|
2918
|
+
span_id = self._trace_client.get_current_span()
|
1981
2919
|
self._finalize_span(span_id)
|
1982
|
-
|
2920
|
+
self._trace_client.reset_current_span(self._span_context_token)
|
1983
2921
|
delattr(self, '_span_context_token')
|
1984
2922
|
return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
|
1985
2923
|
|
@@ -1990,10 +2928,12 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
|
|
1990
2928
|
Otherwise, returns None.
|
1991
2929
|
"""
|
1992
2930
|
if class_name in class_identifiers:
|
1993
|
-
|
2931
|
+
class_config = class_identifiers[class_name]
|
2932
|
+
attr = class_config['identifier']
|
2933
|
+
|
1994
2934
|
if hasattr(instance, attr):
|
1995
2935
|
instance_name = getattr(instance, attr)
|
1996
2936
|
return instance_name
|
1997
2937
|
else:
|
1998
|
-
raise Exception(f"Attribute {
|
1999
|
-
return None
|
2938
|
+
raise Exception(f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator.")
|
2939
|
+
return None
|