judgeval 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/cli.py +65 -0
- judgeval/common/api/api.py +44 -38
- judgeval/common/api/constants.py +18 -5
- judgeval/common/api/json_encoder.py +8 -9
- judgeval/common/tracer/core.py +448 -256
- judgeval/common/tracer/otel_span_processor.py +1 -1
- judgeval/common/tracer/span_processor.py +1 -1
- judgeval/common/tracer/span_transformer.py +2 -1
- judgeval/common/tracer/trace_manager.py +6 -1
- judgeval/common/trainer/__init__.py +5 -0
- judgeval/common/trainer/config.py +125 -0
- judgeval/common/trainer/console.py +151 -0
- judgeval/common/trainer/trainable_model.py +238 -0
- judgeval/common/trainer/trainer.py +301 -0
- judgeval/data/evaluation_run.py +104 -0
- judgeval/data/judgment_types.py +37 -8
- judgeval/data/trace.py +1 -0
- judgeval/data/trace_run.py +0 -2
- judgeval/integrations/langgraph.py +2 -1
- judgeval/judgment_client.py +90 -135
- judgeval/local_eval_queue.py +3 -5
- judgeval/run_evaluation.py +43 -299
- judgeval/scorers/base_scorer.py +9 -10
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +17 -3
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/METADATA +10 -47
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/RECORD +29 -22
- judgeval-0.7.0.dist-info/entry_points.txt +2 -0
- judgeval/evaluation_run.py +0 -80
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/WHEEL +0 -0
- {judgeval-0.5.0.dist-info → judgeval-0.7.0.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer/core.py
CHANGED
@@ -45,7 +45,7 @@ from judgeval.common.tracer.trace_manager import TraceManagerClient
 
 from judgeval.data import Example, Trace, TraceSpan, TraceUsage
 from judgeval.scorers import APIScorerConfig, BaseScorer
-from judgeval.evaluation_run import EvaluationRun
+from judgeval.data.evaluation_run import EvaluationRun
 from judgeval.local_eval_queue import LocalEvaluationQueue
 from judgeval.common.api import JudgmentApiClient
 from judgeval.common.utils import OptExcInfo, validate_api_key
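EvaluationRun moved from the removed top-level module judgeval/evaluation_run.py to judgeval/data/evaluation_run.py (see the file list above), so code importing it from the old location needs the new path. A minimal migration sketch, assuming the class itself is otherwise unchanged for callers:

```python
# judgeval 0.5.0 (module removed in 0.7.0)
# from judgeval.evaluation_run import EvaluationRun

# judgeval 0.7.0
from judgeval.data.evaluation_run import EvaluationRun
```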
@@ -183,8 +183,10 @@ class TraceClient:
         eval_run_name = (
             f"{self.name.capitalize()}-{span_id}-{scorer.score_type.capitalize()}"
         )
-
-
+        hosted_scoring = isinstance(scorer, APIScorerConfig) or (
+            isinstance(scorer, BaseScorer) and scorer.server_hosted
+        )
+        if hosted_scoring:
             eval_run = EvaluationRun(
                 organization_id=self.tracer.organization_id,
                 project_name=self.project_name,
@@ -203,7 +205,7 @@ class TraceClient:
             self.otel_span_processor.queue_evaluation_run(
                 eval_run, span_id=span_id, span_data=current_span
             )
-
+        else:
             # Handle custom scorers using local evaluation queue
             eval_run = EvaluationRun(
                 organization_id=self.tracer.organization_id,
@@ -212,9 +214,7 @@ class TraceClient:
                 examples=[example],
                 scorers=[scorer],
                 model=model,
-                judgment_api_key=self.tracer.api_key,
                 trace_span_id=span_id,
-                trace_id=self.trace_id,
             )
 
             self.add_eval_run(eval_run, start_time)
@@ -251,6 +251,14 @@ class TraceClient:
 
         self.otel_span_processor.queue_span_update(span, span_state="agent_name")
 
+    def record_class_name(self, class_name: str):
+        current_span_id = self.get_current_span()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.class_name = class_name
+
+            self.otel_span_processor.queue_span_update(span, span_state="class_name")
+
     def record_state_before(self, state: dict):
         """Records the agent's state before a tool execution on the current span.
 
@@ -277,35 +285,13 @@ class TraceClient:
 
         self.otel_span_processor.queue_span_update(span, span_state="state_after")
 
-    async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
-        """Helper method to update the output of a trace entry once the coroutine completes"""
-        try:
-            result = await coroutine
-            setattr(span, field, result)
-
-            if field == "output":
-                self.otel_span_processor.queue_span_update(span, span_state="output")
-
-            return result
-        except Exception as e:
-            setattr(span, field, f"Error: {str(e)}")
-
-            if field == "output":
-                self.otel_span_processor.queue_span_update(span, span_state="output")
-
-            raise
-
     def record_output(self, output: Any):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
-            span.output =
-
-            if inspect.iscoroutine(output):
-                asyncio.create_task(self._update_coroutine(span, output, "output"))
+            span.output = output
 
-
-            self.otel_span_processor.queue_span_update(span, span_state="output")
+            self.otel_span_processor.queue_span_update(span, span_state="output")
 
             return span
         return None
@@ -642,6 +628,7 @@ class _DeepTracer:
 
         qual_name = self._get_qual_name(frame)
         instance_name = None
+        class_name = None
        if "self" in frame.f_locals:
             instance = frame.f_locals["self"]
             class_name = instance.__class__.__name__
@@ -715,6 +702,7 @@ class _DeepTracer:
                 parent_span_id=parent_span_id,
                 function=qual_name,
                 agent_name=instance_name,
+                class_name=class_name,
             )
             current_trace.add_span(span)
 
@@ -827,6 +815,8 @@ class Tracer:
         == "true",
         enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
         == "true",
+        show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
+        == "true",
         # S3 configuration
         use_s3: bool = False,
         s3_bucket_name: Optional[str] = None,
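The new show_trace_urls constructor flag mirrors the existing enable_monitoring/enable_evaluations toggles: it defaults from the JUDGMENT_SHOW_TRACE_URLS environment variable and is stored on the tracer instance (next hunk). A hedged sketch of how it could be set; the import path is inferred from this package layout and the remaining Tracer configuration (API key, project) is assumed to come from the environment:

```python
import os

from judgeval.common.tracer import Tracer  # import path assumed from the file list

# Either via the environment, read at construction time per the diff above...
os.environ["JUDGMENT_SHOW_TRACE_URLS"] = "false"

# ...or explicitly as a keyword argument.
tracer = Tracer(show_trace_urls=False)
```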
@@ -871,6 +861,7 @@ class Tracer:
         self.traces: List[Trace] = []
         self.enable_monitoring: bool = enable_monitoring
         self.enable_evaluations: bool = enable_evaluations
+        self.show_trace_urls: bool = show_trace_urls
         self.class_identifiers: Dict[
             str, str
         ] = {}  # Dictionary to store class identifiers
@@ -1063,10 +1054,10 @@ class Tracer:
             # Reset the context variable
             self.reset_current_trace(token)
 
-    def identify(
+    def agent(
         self,
-        identifier: str,
-        track_state: bool = False,
+        identifier: Optional[str] = None,
+        track_state: Optional[bool] = False,
         track_attributes: Optional[List[str]] = None,
         field_mappings: Optional[Dict[str, str]] = None,
     ):
@@ -1104,11 +1095,18 @@ class Tracer:
                 "track_state": track_state,
                 "track_attributes": track_attributes,
                 "field_mappings": field_mappings or {},
+                "class_name": class_name,
             }
             return cls
 
         return decorator
 
+    def identify(self, *args, **kwargs):
+        judgeval_logger.warning(
+            "identify() is deprecated and may not be supported in future versions of judgeval. Use the agent() decorator instead."
+        )
+        return self.agent(*args, **kwargs)
+
     def _capture_instance_state(
         self, instance: Any, class_config: Dict[str, Any]
     ) -> Dict[str, Any]:
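Because identify() now just forwards to agent() with a deprecation warning, existing decorated agent classes keep working. A hedged migration sketch; the class, attribute, and project names are illustrative, and the tracer setup is assumed to be configured through the environment:

```python
from judgeval.common.tracer import Tracer  # import path assumed from the file list

judgment = Tracer(project_name="demo-agents")  # API key assumed to come from env vars

# 0.5.0 style: still accepted in 0.7.0, but logs the deprecation warning above.
@judgment.identify(identifier="name", track_state=True)
class LegacyAgent:
    def __init__(self, name: str):
        self.name = name

# 0.7.0 style: same arguments, and identifier is now optional.
@judgment.agent(identifier="name", track_state=True)
class SupportAgent:
    def __init__(self, name: str):
        self.name = name
```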
@@ -1213,125 +1211,256 @@ class Tracer:
         except Exception:
             return func
 
-
+        def _record_span_data(span, args, kwargs):
+            """Helper function to record inputs, agent info, and state on a span."""
+            # Get class and agent info
+            class_name = None
+            agent_name = None
+            if args and hasattr(args[0], "__class__"):
+                class_name = args[0].__class__.__name__
+                agent_name = get_instance_prefixed_name(
+                    args[0], class_name, self.class_identifiers
+                )
 
-
-
-
-
-
-
+            # Record inputs, agent name, class name
+            inputs = combine_args_kwargs(func, args, kwargs)
+            span.record_input(inputs)
+            if agent_name:
+                span.record_agent_name(agent_name)
+            if class_name and class_name in self.class_identifiers:
+                span.record_class_name(class_name)
+
+            # Capture state before execution
+            self._conditionally_capture_and_record_state(span, args, is_before=True)
+
+            return class_name, agent_name
+
+        def _finalize_span_data(span, result, args):
+            """Helper function to record outputs and final state on a span."""
+            # Record output
+            span.record_output(result)
+
+            # Capture state after execution
+            self._conditionally_capture_and_record_state(span, args, is_before=False)
+
+        def _cleanup_trace(current_trace, trace_token, wrapper_type="function"):
+            """Helper function to handle trace cleanup in finally blocks."""
+            try:
+                trace_id, server_response = current_trace.save(final_save=True)
+
+                complete_trace_data = {
+                    "trace_id": current_trace.trace_id,
+                    "name": current_trace.name,
+                    "project_name": current_trace.project_name,
+                    "created_at": datetime.fromtimestamp(
+                        current_trace.start_time or time.time(),
+                        timezone.utc,
+                    ).isoformat(),
+                    "duration": current_trace.get_duration(),
+                    "trace_spans": [
+                        span.model_dump() for span in current_trace.trace_spans
+                    ],
+                    "evaluation_runs": [
+                        run.model_dump() for run in current_trace.evaluation_runs
+                    ],
+                    "offline_mode": self.offline_mode,
+                    "parent_trace_id": current_trace.parent_trace_id,
+                    "parent_name": current_trace.parent_name,
+                    "customer_id": current_trace.customer_id,
+                    "tags": current_trace.tags,
+                    "metadata": current_trace.metadata,
+                    "update_id": current_trace.update_id,
+                }
+                self.traces.append(complete_trace_data)
+                self.reset_current_trace(trace_token)
+            except Exception as e:
+                judgeval_logger.warning(f"Issue with {wrapper_type} cleanup: {e}")
+
+        def _execute_in_span(
+            current_trace, span_name, span_type, execution_func, args, kwargs
+        ):
+            """Helper function to execute code within a span context."""
+            with current_trace.span(span_name, span_type=span_type) as span:
+                _record_span_data(span, args, kwargs)
+
+                try:
+                    result = execution_func()
+                    _finalize_span_data(span, result, args)
+                    return result
+                except Exception as e:
+                    _capture_exception_for_trace(current_trace, sys.exc_info())
+                    raise e
+
+        async def _execute_in_span_async(
+            current_trace, span_name, span_type, async_execution_func, args, kwargs
+        ):
+            """Helper function to execute async code within a span context."""
+            with current_trace.span(span_name, span_type=span_type) as span:
+                _record_span_data(span, args, kwargs)
+
+                try:
+                    result = await async_execution_func()
+                    _finalize_span_data(span, result, args)
+                    return result
+                except Exception as e:
+                    _capture_exception_for_trace(current_trace, sys.exc_info())
+                    raise e
+
+        def _create_new_trace(self, span_name):
+            """Helper function to create a new trace and set it as current."""
+            trace_id = str(uuid.uuid4())
+            project = self.project_name
+
+            current_trace = TraceClient(
+                self,
+                trace_id,
+                span_name,
+                project_name=project,
+                enable_monitoring=self.enable_monitoring,
+                enable_evaluations=self.enable_evaluations,
+            )
+
+            trace_token = self.set_current_trace(current_trace)
+            return current_trace, trace_token
+
+        def _execute_with_auto_trace_creation(
+            span_name, span_type, execution_func, args, kwargs
+        ):
+            """Helper function that handles automatic trace creation and span execution."""
+            current_trace = self.get_current_trace()
+
+            if not current_trace:
+                current_trace, trace_token = _create_new_trace(self, span_name)
 
-
-
-
-
+                try:
+                    result = _execute_in_span(
+                        current_trace,
+                        span_name,
+                        span_type,
+                        execution_func,
+                        args,
+                        kwargs,
                     )
+                    return result
+                finally:
+                    # Cleanup the trace we created
+                    _cleanup_trace(current_trace, trace_token, "auto_trace")
+            else:
+                # Use existing trace
+                return _execute_in_span(
+                    current_trace, span_name, span_type, execution_func, args, kwargs
+                )
 
-
+        async def _execute_with_auto_trace_creation_async(
+            span_name, span_type, async_execution_func, args, kwargs
+        ):
+            """Helper function that handles automatic trace creation and async span execution."""
+            current_trace = self.get_current_trace()
 
-
-
-                    project = self.project_name
+            if not current_trace:
+                current_trace, trace_token = _create_new_trace(self, span_name)
 
-
-
-
+                try:
+                    result = await _execute_in_span_async(
+                        current_trace,
                        span_name,
-
-
-
+                        span_type,
+                        async_execution_func,
+                        args,
+                        kwargs,
                     )
+                    return result
+                finally:
+                    # Cleanup the trace we created
+                    _cleanup_trace(current_trace, trace_token, "async_auto_trace")
+            else:
+                # Use existing trace
+                return await _execute_in_span_async(
+                    current_trace,
+                    span_name,
+                    span_type,
+                    async_execution_func,
+                    args,
+                    kwargs,
+                )
 
-
+        # Check for generator functions first
+        if inspect.isgeneratorfunction(func):
 
-
-
-
-
-
-
-
-
-
+            @functools.wraps(func)
+            def generator_wrapper(*args, **kwargs):
+                # Get the generator from the original function
+                generator = func(*args, **kwargs)
+
+                # Create wrapper generator that creates spans for each yield
+                def traced_generator():
+                    while True:
+                        try:
+                            # Handle automatic trace creation and span execution
+                            item = _execute_with_auto_trace_creation(
+                                original_span_name,
+                                span_type,
+                                lambda: next(generator),
+                                args,
+                                kwargs,
                            )
+                            yield item
+                        except StopIteration:
+                            break
 
-
-                                if self.deep_tracing:
-                                    with _DeepTracer(self):
-                                        result = await func(*args, **kwargs)
-                                else:
-                                    result = await func(*args, **kwargs)
-                            except Exception as e:
-                                _capture_exception_for_trace(
-                                    current_trace, sys.exc_info()
-                                )
-                                raise e
+                return traced_generator()
 
-
-
-
+            return generator_wrapper
+
+        # Check for async generator functions
+        elif inspect.isasyncgenfunction(func):
 
-
-
-
+            @functools.wraps(func)
+            def async_generator_wrapper(*args, **kwargs):
+                # Get the async generator from the original function
+                async_generator = func(*args, **kwargs)
+
+                # Create wrapper async generator that creates spans for each yield
+                async def traced_async_generator():
+                    while True:
                        try:
-
-
-
-
-
-
-
-                                "duration": current_trace.get_duration(),
-                                "trace_spans": [
-                                    span.model_dump()
-                                    for span in current_trace.trace_spans
-                                ],
-                                "offline_mode": self.offline_mode,
-                                "parent_trace_id": current_trace.parent_trace_id,
-                                "parent_name": current_trace.parent_name,
-                            }
-
-                            trace_id, server_response = current_trace.save(
-                                final_save=True
+                            # Handle automatic trace creation and span execution
+                            item = await _execute_with_auto_trace_creation_async(
+                                original_span_name,
+                                span_type,
+                                lambda: async_generator.__anext__(),
+                                args,
+                                kwargs,
                            )
+                            if inspect.iscoroutine(item):
+                                item = await item
+                            yield item
+                        except StopAsyncIteration:
+                            break
 
-
+                return traced_async_generator()
 
-
-            except Exception as e:
-                judgeval_logger.warning(f"Issue with async_wrapper: {e}")
-                pass
-            else:
-                with current_trace.span(span_name, span_type=span_type) as span:
-                    inputs = combine_args_kwargs(func, args, kwargs)
-                    span.record_input(inputs)
-                    if agent_name:
-                        span.record_agent_name(agent_name)
-
-                    # Capture state before execution
-                    self._conditionally_capture_and_record_state(
-                        span, args, is_before=True
-                    )
+            return async_generator_wrapper
 
-
-                        if self.deep_tracing:
-                            with _DeepTracer(self):
-                                result = await func(*args, **kwargs)
-                        else:
-                            result = await func(*args, **kwargs)
-                    except Exception as e:
-                        _capture_exception_for_trace(current_trace, sys.exc_info())
-                        raise e
-
-                    # Capture state after execution
-                    self._conditionally_capture_and_record_state(
-                        span, args, is_before=False
-                    )
+        elif asyncio.iscoroutinefunction(func):
 
-
-
+            @functools.wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                nonlocal original_span_name
+                span_name = original_span_name
+
+                async def async_execution():
+                    if self.deep_tracing:
+                        with _DeepTracer(self):
+                            return await func(*args, **kwargs)
+                    else:
+                        return await func(*args, **kwargs)
+
+                result = await _execute_with_auto_trace_creation_async(
+                    span_name, span_type, async_execution, args, kwargs
+                )
+
+                return result
 
             return async_wrapper
         else:
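The net effect of this refactor is that the tracing decorator (exposed as observe() in judgeval's public API) now dispatches on generator, async-generator, coroutine, and plain functions, routing every call through the shared span and trace helpers above and wrapping each yield of a generator in its own span. A hedged sketch of the generator case; the decorator name, its arguments, and the tracer setup are assumed from judgeval's documented usage:

```python
from judgeval.common.tracer import Tracer  # import path assumed from the file list

judgment = Tracer(project_name="demo")  # credentials assumed to come from the environment

@judgment.observe(span_type="tool")
def stream_chunks(text: str):
    # Per generator_wrapper above, each next() on this generator runs inside
    # _execute_with_auto_trace_creation, i.e. one span per yielded item.
    for word in text.split():
        yield word

for chunk in stream_chunks("hello traced world"):
    print(chunk)
```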
@@ -1339,122 +1468,18 @@ class Tracer:
             @functools.wraps(func)
             def wrapper(*args, **kwargs):
                 nonlocal original_span_name
-                class_name = None
                 span_name = original_span_name
-                agent_name = None
-                if args and hasattr(args[0], "__class__"):
-                    class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(
-                        args[0], class_name, self.class_identifiers
-                    )
-                # Get current trace from context
-                current_trace = self.get_current_trace()
-
-                # If there's no current trace, create a root trace
-                if not current_trace:
-                    trace_id = str(uuid.uuid4())
-                    project = self.project_name
-
-                    # Create a new trace client to serve as the root
-                    current_trace = TraceClient(
-                        self,
-                        trace_id,
-                        span_name,
-                        project_name=project,
-                        enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations,
-                    )
-
-                    trace_token = self.set_current_trace(current_trace)
-
-                    try:
-                        with current_trace.span(span_name, span_type=span_type) as span:
-                            # Record inputs
-                            inputs = combine_args_kwargs(func, args, kwargs)
-                            span.record_input(inputs)
-                            if agent_name:
-                                span.record_agent_name(agent_name)
-                            # Capture state before execution
-                            self._conditionally_capture_and_record_state(
-                                span, args, is_before=True
-                            )
-
-                            try:
-                                if self.deep_tracing:
-                                    with _DeepTracer(self):
-                                        result = func(*args, **kwargs)
-                                else:
-                                    result = func(*args, **kwargs)
-                            except Exception as e:
-                                _capture_exception_for_trace(
-                                    current_trace, sys.exc_info()
-                                )
-                                raise e
 
-
-
-
-                            )
-
-
-                            span.record_output(result)
-                            return result
-                    finally:
-                        try:
-                            trace_id, server_response = current_trace.save(
-                                final_save=True
-                            )
+                def sync_execution():
+                    if self.deep_tracing:
+                        with _DeepTracer(self):
+                            return func(*args, **kwargs)
+                    else:
+                        return func(*args, **kwargs)
 
-
-
-
-                                "created_at": datetime.fromtimestamp(
-                                    current_trace.start_time or time.time(),
-                                    timezone.utc,
-                                ).isoformat(),
-                                "duration": current_trace.get_duration(),
-                                "trace_spans": [
-                                    span.model_dump()
-                                    for span in current_trace.trace_spans
-                                ],
-                                "offline_mode": self.offline_mode,
-                                "parent_trace_id": current_trace.parent_trace_id,
-                                "parent_name": current_trace.parent_name,
-                            }
-                            self.traces.append(complete_trace_data)
-                            self.reset_current_trace(trace_token)
-                        except Exception as e:
-                            judgeval_logger.warning(f"Issue with save: {e}")
-                            pass
-                else:
-                    with current_trace.span(span_name, span_type=span_type) as span:
-                        inputs = combine_args_kwargs(func, args, kwargs)
-                        span.record_input(inputs)
-                        if agent_name:
-                            span.record_agent_name(agent_name)
-
-                        # Capture state before execution
-                        self._conditionally_capture_and_record_state(
-                            span, args, is_before=True
-                        )
-
-                        try:
-                            if self.deep_tracing:
-                                with _DeepTracer(self):
-                                    result = func(*args, **kwargs)
-                            else:
-                                result = func(*args, **kwargs)
-                        except Exception as e:
-                            _capture_exception_for_trace(current_trace, sys.exc_info())
-                            raise e
-
-                        # Capture state after execution
-                        self._conditionally_capture_and_record_state(
-                            span, args, is_before=False
-                        )
-
-                        span.record_output(result)
-                        return result
+                return _execute_with_auto_trace_creation(
+                    span_name, span_type, sync_execution, args, kwargs
+                )
 
             return wrapper
 
@@ -1709,6 +1734,93 @@ class Tracer:
                     f"Error during background service shutdown: {e}"
                 )
 
+    def trace_to_message_history(
+        self, trace: Union[Trace, TraceClient]
+    ) -> List[Dict[str, str]]:
+        """
+        Extract message history from a trace for training purposes.
+
+        This method processes trace spans to reconstruct the conversation flow,
+        extracting messages in chronological order from LLM, user, and tool spans.
+
+        Args:
+            trace: Trace or TraceClient instance to extract messages from
+
+        Returns:
+            List of message dictionaries with 'role' and 'content' keys
+
+        Raises:
+            ValueError: If no trace is provided
+        """
+        if not trace:
+            raise ValueError("No trace provided")
+
+        # Handle both Trace and TraceClient objects
+        if isinstance(trace, TraceClient):
+            spans = trace.trace_spans
+        else:
+            spans = trace.trace_spans if hasattr(trace, "trace_spans") else []
+
+        messages = []
+        first_found = False
+
+        # Process spans in chronological order
+        for span in sorted(
+            spans, key=lambda s: s.created_at if hasattr(s, "created_at") else 0
+        ):
+            # Skip spans without output (except for first LLM span which may have input messages)
+            if span.output is None and span.span_type != "llm":
+                continue
+
+            if span.span_type == "llm":
+                # For the first LLM span, extract input messages (system + user prompts)
+                if not first_found and hasattr(span, "inputs") and span.inputs:
+                    input_messages = span.inputs.get("messages", [])
+                    if input_messages:
+                        first_found = True
+                        # Add input messages (typically system and user messages)
+                        for msg in input_messages:
+                            if (
+                                isinstance(msg, dict)
+                                and "role" in msg
+                                and "content" in msg
+                            ):
+                                messages.append(
+                                    {"role": msg["role"], "content": msg["content"]}
+                                )
+
+                # Add assistant response from span output
+                if span.output is not None:
+                    messages.append({"role": "assistant", "content": str(span.output)})
+
+            elif span.span_type == "user":
+                # Add user messages
+                if span.output is not None:
+                    messages.append({"role": "user", "content": str(span.output)})
+
+            elif span.span_type == "tool":
+                # Add tool responses as user messages (common pattern in training)
+                if span.output is not None:
+                    messages.append({"role": "user", "content": str(span.output)})
+
+        return messages
+
+    def get_current_message_history(self) -> List[Dict[str, str]]:
+        """
+        Get message history from the current trace.
+
+        Returns:
+            List of message dictionaries from the current trace context
+
+        Raises:
+            ValueError: If no current trace is found
+        """
+        current_trace = self.get_current_trace()
+        if not current_trace:
+            raise ValueError("No current trace found")
+
+        return self.trace_to_message_history(current_trace)
+
 
 def _get_current_trace(
     trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
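These two additions turn a trace into chat-style training data, mapping llm spans to assistant messages and user/tool spans to user messages as shown above. A hedged usage sketch; `judgment` is a Tracer instance configured as in the earlier examples:

```python
# Inside traced code, reconstruct the conversation recorded so far.
history = judgment.get_current_message_history()
for message in history:
    print(f"{message['role']}: {message['content'][:80]}")

# A stored Trace or a TraceClient can also be converted directly:
# messages = judgment.trace_to_message_history(some_finished_trace)
```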
@@ -1724,7 +1836,7 @@ def wrap(
 ) -> Any:
     """
     Wraps an API client to add tracing capabilities.
-    Supports OpenAI, Together, Anthropic,
+    Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
     Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
     (
@@ -1849,6 +1961,39 @@ def wrap(
         setattr(client.chat.completions, "create", wrapped(original_create))
     elif isinstance(client, (groq_AsyncGroq)):
         setattr(client.chat.completions, "create", wrapped_async(original_create))
+
+    # Check for TrainableModel from judgeval.common.trainer
+    try:
+        from judgeval.common.trainer import TrainableModel
+
+        if isinstance(client, TrainableModel):
+            # Define a wrapper function that can be reapplied to new model instances
+            def wrap_model_instance(model_instance):
+                """Wrap a model instance with tracing functionality"""
+                if hasattr(model_instance, "chat") and hasattr(
+                    model_instance.chat, "completions"
+                ):
+                    if hasattr(model_instance.chat.completions, "create"):
+                        setattr(
+                            model_instance.chat.completions,
+                            "create",
+                            wrapped(model_instance.chat.completions.create),
+                        )
+                    if hasattr(model_instance.chat.completions, "acreate"):
+                        setattr(
+                            model_instance.chat.completions,
+                            "acreate",
+                            wrapped_async(model_instance.chat.completions.acreate),
+                        )
+
+            # Register the wrapper function with the TrainableModel
+            client._register_tracer_wrapper(wrap_model_instance)
+
+            # Apply wrapping to the current model
+            wrap_model_instance(client._current_model)
+    except ImportError:
+        pass  # TrainableModel not available
+
     return client
 
 
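wrap() now recognizes the TrainableModel shipped in the new judgeval.common.trainer package (see the file list) alongside the existing OpenAI, Anthropic, Together, Google GenAI, and Groq clients. A hedged sketch of the intended call pattern; TrainableModel's constructor is not shown in this diff, so its arguments below are purely illustrative:

```python
from judgeval.common.tracer import wrap  # import path assumed; wrap() is defined in the module diffed above
from judgeval.common.trainer import TrainableModel

# Hypothetical construction: the real constructor arguments may differ.
model = TrainableModel()

# After wrapping, chat.completions.create()/acreate() on the current underlying
# model are traced, and the same wrapper is re-applied whenever the trainer swaps
# in a new model instance via _register_tracer_wrapper().
model = wrap(model)
```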
@@ -1955,6 +2100,22 @@ def _get_client_config(
         return "GROQ_API_CALL", client.chat.completions.create, None, None, None
     elif isinstance(client, (groq_AsyncGroq)):
         return "GROQ_API_CALL", client.chat.completions.create, None, None, None
+
+    # Check for TrainableModel
+    try:
+        from judgeval.common.trainer import TrainableModel
+
+        if isinstance(client, TrainableModel):
+            return (
+                "FIREWORKS_TRAINABLE_MODEL_CALL",
+                client._current_model.chat.completions.create,
+                None,
+                None,
+                None,
+            )
+    except ImportError:
+        pass  # TrainableModel not available
+
     raise ValueError(f"Unsupported client type: {type(client)}")
 
 
@@ -2133,6 +2294,37 @@ def _format_output_data(
             cache_creation_input_tokens,
         )
 
+    # Check for TrainableModel
+    try:
+        from judgeval.common.trainer import TrainableModel
+
+        if isinstance(client, TrainableModel):
+            # TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
+            if (
+                hasattr(response, "model")
+                and hasattr(response, "usage")
+                and hasattr(response, "choices")
+            ):
+                model_name = response.model
+                prompt_tokens = response.usage.prompt_tokens if response.usage else 0
+                completion_tokens = (
+                    response.usage.completion_tokens if response.usage else 0
+                )
+                message_content = response.choices[0].message.content
+
+                # Use LiteLLM cost calculation with fireworks_ai prefix
+                # LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
+                fireworks_model_name = f"fireworks_ai/{model_name}"
+                return message_content, _create_usage(
+                    fireworks_model_name,
+                    prompt_tokens,
+                    completion_tokens,
+                    cache_read_input_tokens,
+                    cache_creation_input_tokens,
+                )
+    except ImportError:
+        pass  # TrainableModel not available
+
     judgeval_logger.warning(f"Unsupported client type: {type(client)}")
     return None, None
 
@@ -2223,13 +2415,13 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
     """
     if class_name in class_identifiers:
         class_config = class_identifiers[class_name]
-        attr = class_config
-
-
-
-
-
-
-
-
-
+        attr = class_config.get("identifier")
+        if attr:
+            if hasattr(instance, attr) and not callable(getattr(instance, attr)):
+                instance_name = getattr(instance, attr)
+                return instance_name
+            else:
+                raise Exception(
+                    f"Attribute {attr} does not exist for {class_name}. Check your agent() decorator."
+                )
+    return None