judgeval 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,7 +45,7 @@ from judgeval.common.tracer.trace_manager import TraceManagerClient
 
 from judgeval.data import Example, Trace, TraceSpan, TraceUsage
 from judgeval.scorers import APIScorerConfig, BaseScorer
-from judgeval.evaluation_run import EvaluationRun
+from judgeval.data.evaluation_run import EvaluationRun
 from judgeval.local_eval_queue import LocalEvaluationQueue
 from judgeval.common.api import JudgmentApiClient
 from judgeval.common.utils import OptExcInfo, validate_api_key
@@ -183,8 +183,10 @@ class TraceClient:
 eval_run_name = (
 f"{self.name.capitalize()}-{span_id}-{scorer.score_type.capitalize()}"
 )
-
-if isinstance(scorer, APIScorerConfig):
+hosted_scoring = isinstance(scorer, APIScorerConfig) or (
+isinstance(scorer, BaseScorer) and scorer.server_hosted
+)
+if hosted_scoring:
 eval_run = EvaluationRun(
 organization_id=self.tracer.organization_id,
 project_name=self.project_name,
@@ -203,7 +205,7 @@ class TraceClient:
 self.otel_span_processor.queue_evaluation_run(
 eval_run, span_id=span_id, span_data=current_span
 )
-elif isinstance(scorer, BaseScorer):
+else:
 # Handle custom scorers using local evaluation queue
 eval_run = EvaluationRun(
 organization_id=self.tracer.organization_id,
@@ -212,9 +214,7 @@ class TraceClient:
 examples=[example],
 scorers=[scorer],
 model=model,
-judgment_api_key=self.tracer.api_key,
 trace_span_id=span_id,
-trace_id=self.trace_id,
 )
 
 self.add_eval_run(eval_run, start_time)
@@ -251,6 +251,14 @@ class TraceClient:
 
 self.otel_span_processor.queue_span_update(span, span_state="agent_name")
 
+def record_class_name(self, class_name: str):
+current_span_id = self.get_current_span()
+if current_span_id:
+span = self.span_id_to_span[current_span_id]
+span.class_name = class_name
+
+self.otel_span_processor.queue_span_update(span, span_state="class_name")
+
 def record_state_before(self, state: dict):
 """Records the agent's state before a tool execution on the current span.
 
@@ -277,35 +285,13 @@ class TraceClient:
 
 self.otel_span_processor.queue_span_update(span, span_state="state_after")
 
-async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
-"""Helper method to update the output of a trace entry once the coroutine completes"""
-try:
-result = await coroutine
-setattr(span, field, result)
-
-if field == "output":
-self.otel_span_processor.queue_span_update(span, span_state="output")
-
-return result
-except Exception as e:
-setattr(span, field, f"Error: {str(e)}")
-
-if field == "output":
-self.otel_span_processor.queue_span_update(span, span_state="output")
-
-raise
-
 def record_output(self, output: Any):
 current_span_id = self.get_current_span()
 if current_span_id:
 span = self.span_id_to_span[current_span_id]
-span.output = "<pending>" if inspect.iscoroutine(output) else output
-
-if inspect.iscoroutine(output):
-asyncio.create_task(self._update_coroutine(span, output, "output"))
+span.output = output
 
-if not inspect.iscoroutine(output):
-self.otel_span_processor.queue_span_update(span, span_state="output")
+self.otel_span_processor.queue_span_update(span, span_state="output")
 
 return span
 return None
@@ -642,6 +628,7 @@ class _DeepTracer:
 
 qual_name = self._get_qual_name(frame)
 instance_name = None
+class_name = None
 if "self" in frame.f_locals:
 instance = frame.f_locals["self"]
 class_name = instance.__class__.__name__
@@ -715,6 +702,7 @@ class _DeepTracer:
 parent_span_id=parent_span_id,
 function=qual_name,
 agent_name=instance_name,
+class_name=class_name,
 )
 current_trace.add_span(span)
 
@@ -827,6 +815,8 @@ class Tracer:
 == "true",
 enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
 == "true",
+show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
+== "true",
 # S3 configuration
 use_s3: bool = False,
 s3_bucket_name: Optional[str] = None,
@@ -871,6 +861,7 @@ class Tracer:
 self.traces: List[Trace] = []
 self.enable_monitoring: bool = enable_monitoring
 self.enable_evaluations: bool = enable_evaluations
+self.show_trace_urls: bool = show_trace_urls
 self.class_identifiers: Dict[
 str, str
 ] = {} # Dictionary to store class identifiers
@@ -1063,10 +1054,10 @@ class Tracer:
 # Reset the context variable
 self.reset_current_trace(token)
 
-def identify(
+def agent(
 self,
-identifier: str,
-track_state: bool = False,
+identifier: Optional[str] = None,
+track_state: Optional[bool] = False,
 track_attributes: Optional[List[str]] = None,
 field_mappings: Optional[Dict[str, str]] = None,
 ):
@@ -1104,11 +1095,18 @@ class Tracer:
 "track_state": track_state,
 "track_attributes": track_attributes,
 "field_mappings": field_mappings or {},
+"class_name": class_name,
 }
 return cls
 
 return decorator
 
+def identify(self, *args, **kwargs):
+judgeval_logger.warning(
+"identify() is deprecated and may not be supported in future versions of judgeval. Use the agent() decorator instead."
+)
+return self.agent(*args, **kwargs)
+
 def _capture_instance_state(
 self, instance: Any, class_config: Dict[str, Any]
 ) -> Dict[str, Any]:
@@ -1213,125 +1211,256 @@ class Tracer:
 except Exception:
 return func
 
-if asyncio.iscoroutinefunction(func):
+def _record_span_data(span, args, kwargs):
+"""Helper function to record inputs, agent info, and state on a span."""
+# Get class and agent info
+class_name = None
+agent_name = None
+if args and hasattr(args[0], "__class__"):
+class_name = args[0].__class__.__name__
+agent_name = get_instance_prefixed_name(
+args[0], class_name, self.class_identifiers
+)
 
-@functools.wraps(func)
-async def async_wrapper(*args, **kwargs):
-nonlocal original_span_name
-class_name = None
-span_name = original_span_name
-agent_name = None
+# Record inputs, agent name, class name
+inputs = combine_args_kwargs(func, args, kwargs)
+span.record_input(inputs)
+if agent_name:
+span.record_agent_name(agent_name)
+if class_name and class_name in self.class_identifiers:
+span.record_class_name(class_name)
+
+# Capture state before execution
+self._conditionally_capture_and_record_state(span, args, is_before=True)
+
+return class_name, agent_name
+
+def _finalize_span_data(span, result, args):
+"""Helper function to record outputs and final state on a span."""
+# Record output
+span.record_output(result)
+
+# Capture state after execution
+self._conditionally_capture_and_record_state(span, args, is_before=False)
+
+def _cleanup_trace(current_trace, trace_token, wrapper_type="function"):
+"""Helper function to handle trace cleanup in finally blocks."""
+try:
+trace_id, server_response = current_trace.save(final_save=True)
+
+complete_trace_data = {
+"trace_id": current_trace.trace_id,
+"name": current_trace.name,
+"project_name": current_trace.project_name,
+"created_at": datetime.fromtimestamp(
+current_trace.start_time or time.time(),
+timezone.utc,
+).isoformat(),
+"duration": current_trace.get_duration(),
+"trace_spans": [
+span.model_dump() for span in current_trace.trace_spans
+],
+"evaluation_runs": [
+run.model_dump() for run in current_trace.evaluation_runs
+],
+"offline_mode": self.offline_mode,
+"parent_trace_id": current_trace.parent_trace_id,
+"parent_name": current_trace.parent_name,
+"customer_id": current_trace.customer_id,
+"tags": current_trace.tags,
+"metadata": current_trace.metadata,
+"update_id": current_trace.update_id,
+}
+self.traces.append(complete_trace_data)
+self.reset_current_trace(trace_token)
+except Exception as e:
+judgeval_logger.warning(f"Issue with {wrapper_type} cleanup: {e}")
+
+def _execute_in_span(
+current_trace, span_name, span_type, execution_func, args, kwargs
+):
+"""Helper function to execute code within a span context."""
+with current_trace.span(span_name, span_type=span_type) as span:
+_record_span_data(span, args, kwargs)
+
+try:
+result = execution_func()
+_finalize_span_data(span, result, args)
+return result
+except Exception as e:
+_capture_exception_for_trace(current_trace, sys.exc_info())
+raise e
+
+async def _execute_in_span_async(
+current_trace, span_name, span_type, async_execution_func, args, kwargs
+):
+"""Helper function to execute async code within a span context."""
+with current_trace.span(span_name, span_type=span_type) as span:
+_record_span_data(span, args, kwargs)
+
+try:
+result = await async_execution_func()
+_finalize_span_data(span, result, args)
+return result
+except Exception as e:
+_capture_exception_for_trace(current_trace, sys.exc_info())
+raise e
+
+def _create_new_trace(self, span_name):
+"""Helper function to create a new trace and set it as current."""
+trace_id = str(uuid.uuid4())
+project = self.project_name
+
+current_trace = TraceClient(
+self,
+trace_id,
+span_name,
+project_name=project,
+enable_monitoring=self.enable_monitoring,
+enable_evaluations=self.enable_evaluations,
+)
+
+trace_token = self.set_current_trace(current_trace)
+return current_trace, trace_token
+
+def _execute_with_auto_trace_creation(
+span_name, span_type, execution_func, args, kwargs
+):
+"""Helper function that handles automatic trace creation and span execution."""
+current_trace = self.get_current_trace()
+
+if not current_trace:
+current_trace, trace_token = _create_new_trace(self, span_name)
 
-if args and hasattr(args[0], "__class__"):
-class_name = args[0].__class__.__name__
-agent_name = get_instance_prefixed_name(
-args[0], class_name, self.class_identifiers
+try:
+result = _execute_in_span(
+current_trace,
+span_name,
+span_type,
+execution_func,
+args,
+kwargs,
 )
+return result
+finally:
+# Cleanup the trace we created
+_cleanup_trace(current_trace, trace_token, "auto_trace")
+else:
+# Use existing trace
+return _execute_in_span(
+current_trace, span_name, span_type, execution_func, args, kwargs
+)
 
-current_trace = self.get_current_trace()
+async def _execute_with_auto_trace_creation_async(
+span_name, span_type, async_execution_func, args, kwargs
+):
+"""Helper function that handles automatic trace creation and async span execution."""
+current_trace = self.get_current_trace()
 
-if not current_trace:
-trace_id = str(uuid.uuid4())
-project = self.project_name
+if not current_trace:
+current_trace, trace_token = _create_new_trace(self, span_name)
 
-current_trace = TraceClient(
-self,
-trace_id,
+try:
+result = await _execute_in_span_async(
+current_trace,
 span_name,
-project_name=project,
-enable_monitoring=self.enable_monitoring,
-enable_evaluations=self.enable_evaluations,
+span_type,
+async_execution_func,
+args,
+kwargs,
 )
+return result
+finally:
+# Cleanup the trace we created
+_cleanup_trace(current_trace, trace_token, "async_auto_trace")
+else:
+# Use existing trace
+return await _execute_in_span_async(
+current_trace,
+span_name,
+span_type,
+async_execution_func,
+args,
+kwargs,
+)
 
-trace_token = self.set_current_trace(current_trace)
+# Check for generator functions first
+if inspect.isgeneratorfunction(func):
 
-try:
-with current_trace.span(span_name, span_type=span_type) as span:
-inputs = combine_args_kwargs(func, args, kwargs)
-span.record_input(inputs)
-if agent_name:
-span.record_agent_name(agent_name)
-
-self._conditionally_capture_and_record_state(
-span, args, is_before=True
+@functools.wraps(func)
+def generator_wrapper(*args, **kwargs):
+# Get the generator from the original function
+generator = func(*args, **kwargs)
+
+# Create wrapper generator that creates spans for each yield
+def traced_generator():
+while True:
+try:
+# Handle automatic trace creation and span execution
+item = _execute_with_auto_trace_creation(
+original_span_name,
+span_type,
+lambda: next(generator),
+args,
+kwargs,
 )
+yield item
+except StopIteration:
+break
 
-try:
-if self.deep_tracing:
-with _DeepTracer(self):
-result = await func(*args, **kwargs)
-else:
-result = await func(*args, **kwargs)
-except Exception as e:
-_capture_exception_for_trace(
-current_trace, sys.exc_info()
-)
-raise e
+return traced_generator()
 
-self._conditionally_capture_and_record_state(
-span, args, is_before=False
-)
+return generator_wrapper
+
+# Check for async generator functions
+elif inspect.isasyncgenfunction(func):
 
-span.record_output(result)
-return result
-finally:
+@functools.wraps(func)
+def async_generator_wrapper(*args, **kwargs):
+# Get the async generator from the original function
+async_generator = func(*args, **kwargs)
+
+# Create wrapper async generator that creates spans for each yield
+async def traced_async_generator():
+while True:
 try:
-complete_trace_data = {
-"trace_id": current_trace.trace_id,
-"name": current_trace.name,
-"created_at": datetime.fromtimestamp(
-current_trace.start_time or time.time(),
-timezone.utc,
-).isoformat(),
-"duration": current_trace.get_duration(),
-"trace_spans": [
-span.model_dump()
-for span in current_trace.trace_spans
-],
-"offline_mode": self.offline_mode,
-"parent_trace_id": current_trace.parent_trace_id,
-"parent_name": current_trace.parent_name,
-}
-
-trace_id, server_response = current_trace.save(
-final_save=True
+# Handle automatic trace creation and span execution
+item = await _execute_with_auto_trace_creation_async(
+original_span_name,
+span_type,
+lambda: async_generator.__anext__(),
+args,
+kwargs,
 )
+if inspect.iscoroutine(item):
+item = await item
+yield item
+except StopAsyncIteration:
+break
 
-self.traces.append(complete_trace_data)
+return traced_async_generator()
 
-self.reset_current_trace(trace_token)
-except Exception as e:
-judgeval_logger.warning(f"Issue with async_wrapper: {e}")
-pass
-else:
-with current_trace.span(span_name, span_type=span_type) as span:
-inputs = combine_args_kwargs(func, args, kwargs)
-span.record_input(inputs)
-if agent_name:
-span.record_agent_name(agent_name)
-
-# Capture state before execution
-self._conditionally_capture_and_record_state(
-span, args, is_before=True
-)
+return async_generator_wrapper
 
-try:
-if self.deep_tracing:
-with _DeepTracer(self):
-result = await func(*args, **kwargs)
-else:
-result = await func(*args, **kwargs)
-except Exception as e:
-_capture_exception_for_trace(current_trace, sys.exc_info())
-raise e
-
-# Capture state after execution
-self._conditionally_capture_and_record_state(
-span, args, is_before=False
-)
+elif asyncio.iscoroutinefunction(func):
 
-span.record_output(result)
-return result
+@functools.wraps(func)
+async def async_wrapper(*args, **kwargs):
+nonlocal original_span_name
+span_name = original_span_name
+
+async def async_execution():
+if self.deep_tracing:
+with _DeepTracer(self):
+return await func(*args, **kwargs)
+else:
+return await func(*args, **kwargs)
+
+result = await _execute_with_auto_trace_creation_async(
+span_name, span_type, async_execution, args, kwargs
+)
+
+return result
 
 return async_wrapper
 else:
@@ -1339,122 +1468,18 @@ class Tracer:
 @functools.wraps(func)
 def wrapper(*args, **kwargs):
 nonlocal original_span_name
-class_name = None
 span_name = original_span_name
-agent_name = None
-if args and hasattr(args[0], "__class__"):
-class_name = args[0].__class__.__name__
-agent_name = get_instance_prefixed_name(
-args[0], class_name, self.class_identifiers
-)
-# Get current trace from context
-current_trace = self.get_current_trace()
-
-# If there's no current trace, create a root trace
-if not current_trace:
-trace_id = str(uuid.uuid4())
-project = self.project_name
-
-# Create a new trace client to serve as the root
-current_trace = TraceClient(
-self,
-trace_id,
-span_name,
-project_name=project,
-enable_monitoring=self.enable_monitoring,
-enable_evaluations=self.enable_evaluations,
-)
-
-trace_token = self.set_current_trace(current_trace)
-
-try:
-with current_trace.span(span_name, span_type=span_type) as span:
-# Record inputs
-inputs = combine_args_kwargs(func, args, kwargs)
-span.record_input(inputs)
-if agent_name:
-span.record_agent_name(agent_name)
-# Capture state before execution
-self._conditionally_capture_and_record_state(
-span, args, is_before=True
-)
-
-try:
-if self.deep_tracing:
-with _DeepTracer(self):
-result = func(*args, **kwargs)
-else:
-result = func(*args, **kwargs)
-except Exception as e:
-_capture_exception_for_trace(
-current_trace, sys.exc_info()
-)
-raise e
 
-# Capture state after execution
-self._conditionally_capture_and_record_state(
-span, args, is_before=False
-)
-
-# Record output
-span.record_output(result)
-return result
-finally:
-try:
-trace_id, server_response = current_trace.save(
-final_save=True
-)
+def sync_execution():
+if self.deep_tracing:
+with _DeepTracer(self):
+return func(*args, **kwargs)
+else:
+return func(*args, **kwargs)
 
-complete_trace_data = {
-"trace_id": current_trace.trace_id,
-"name": current_trace.name,
-"created_at": datetime.fromtimestamp(
-current_trace.start_time or time.time(),
-timezone.utc,
-).isoformat(),
-"duration": current_trace.get_duration(),
-"trace_spans": [
-span.model_dump()
-for span in current_trace.trace_spans
-],
-"offline_mode": self.offline_mode,
-"parent_trace_id": current_trace.parent_trace_id,
-"parent_name": current_trace.parent_name,
-}
-self.traces.append(complete_trace_data)
-self.reset_current_trace(trace_token)
-except Exception as e:
-judgeval_logger.warning(f"Issue with save: {e}")
-pass
-else:
-with current_trace.span(span_name, span_type=span_type) as span:
-inputs = combine_args_kwargs(func, args, kwargs)
-span.record_input(inputs)
-if agent_name:
-span.record_agent_name(agent_name)
-
-# Capture state before execution
-self._conditionally_capture_and_record_state(
-span, args, is_before=True
-)
-
-try:
-if self.deep_tracing:
-with _DeepTracer(self):
-result = func(*args, **kwargs)
-else:
-result = func(*args, **kwargs)
-except Exception as e:
-_capture_exception_for_trace(current_trace, sys.exc_info())
-raise e
-
-# Capture state after execution
-self._conditionally_capture_and_record_state(
-span, args, is_before=False
-)
-
-span.record_output(result)
-return result
+return _execute_with_auto_trace_creation(
+span_name, span_type, sync_execution, args, kwargs
+)
 
 return wrapper
 
@@ -1709,6 +1734,93 @@ class Tracer:
 f"Error during background service shutdown: {e}"
 )
 
+def trace_to_message_history(
+self, trace: Union[Trace, TraceClient]
+) -> List[Dict[str, str]]:
+"""
+Extract message history from a trace for training purposes.
+
+This method processes trace spans to reconstruct the conversation flow,
+extracting messages in chronological order from LLM, user, and tool spans.
+
+Args:
+trace: Trace or TraceClient instance to extract messages from
+
+Returns:
+List of message dictionaries with 'role' and 'content' keys
+
+Raises:
+ValueError: If no trace is provided
+"""
+if not trace:
+raise ValueError("No trace provided")
+
+# Handle both Trace and TraceClient objects
+if isinstance(trace, TraceClient):
+spans = trace.trace_spans
+else:
+spans = trace.trace_spans if hasattr(trace, "trace_spans") else []
+
+messages = []
+first_found = False
+
+# Process spans in chronological order
+for span in sorted(
+spans, key=lambda s: s.created_at if hasattr(s, "created_at") else 0
+):
+# Skip spans without output (except for first LLM span which may have input messages)
+if span.output is None and span.span_type != "llm":
+continue
+
+if span.span_type == "llm":
+# For the first LLM span, extract input messages (system + user prompts)
+if not first_found and hasattr(span, "inputs") and span.inputs:
+input_messages = span.inputs.get("messages", [])
+if input_messages:
+first_found = True
+# Add input messages (typically system and user messages)
+for msg in input_messages:
+if (
+isinstance(msg, dict)
+and "role" in msg
+and "content" in msg
+):
+messages.append(
+{"role": msg["role"], "content": msg["content"]}
+)
+
+# Add assistant response from span output
+if span.output is not None:
+messages.append({"role": "assistant", "content": str(span.output)})
+
+elif span.span_type == "user":
+# Add user messages
+if span.output is not None:
+messages.append({"role": "user", "content": str(span.output)})
+
+elif span.span_type == "tool":
+# Add tool responses as user messages (common pattern in training)
+if span.output is not None:
+messages.append({"role": "user", "content": str(span.output)})
+
+return messages
+
+def get_current_message_history(self) -> List[Dict[str, str]]:
+"""
+Get message history from the current trace.
+
+Returns:
+List of message dictionaries from the current trace context
+
+Raises:
+ValueError: If no current trace is found
+"""
+current_trace = self.get_current_trace()
+if not current_trace:
+raise ValueError("No current trace found")
+
+return self.trace_to_message_history(current_trace)
+
 
 
 def _get_current_trace(
@@ -1724,7 +1836,7 @@ def wrap(
 ) -> Any:
 """
 Wraps an API client to add tracing capabilities.
-Supports OpenAI, Together, Anthropic, and Google GenAI clients.
+Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
 Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
 """
 (
@@ -1849,6 +1961,39 @@ def wrap(
 setattr(client.chat.completions, "create", wrapped(original_create))
 elif isinstance(client, (groq_AsyncGroq)):
 setattr(client.chat.completions, "create", wrapped_async(original_create))
+
+# Check for TrainableModel from judgeval.common.trainer
+try:
+from judgeval.common.trainer import TrainableModel
+
+if isinstance(client, TrainableModel):
+# Define a wrapper function that can be reapplied to new model instances
+def wrap_model_instance(model_instance):
+"""Wrap a model instance with tracing functionality"""
+if hasattr(model_instance, "chat") and hasattr(
+model_instance.chat, "completions"
+):
+if hasattr(model_instance.chat.completions, "create"):
+setattr(
+model_instance.chat.completions,
+"create",
+wrapped(model_instance.chat.completions.create),
+)
+if hasattr(model_instance.chat.completions, "acreate"):
+setattr(
+model_instance.chat.completions,
+"acreate",
+wrapped_async(model_instance.chat.completions.acreate),
+)
+
+# Register the wrapper function with the TrainableModel
+client._register_tracer_wrapper(wrap_model_instance)
+
+# Apply wrapping to the current model
+wrap_model_instance(client._current_model)
+except ImportError:
+pass # TrainableModel not available
+
 return client
 
 
@@ -1955,6 +2100,22 @@ def _get_client_config(
 return "GROQ_API_CALL", client.chat.completions.create, None, None, None
 elif isinstance(client, (groq_AsyncGroq)):
 return "GROQ_API_CALL", client.chat.completions.create, None, None, None
+
+# Check for TrainableModel
+try:
+from judgeval.common.trainer import TrainableModel
+
+if isinstance(client, TrainableModel):
+return (
+"FIREWORKS_TRAINABLE_MODEL_CALL",
+client._current_model.chat.completions.create,
+None,
+None,
+None,
+)
+except ImportError:
+pass # TrainableModel not available
+
 raise ValueError(f"Unsupported client type: {type(client)}")
 
 
@@ -2133,6 +2294,37 @@ def _format_output_data(
 cache_creation_input_tokens,
 )
 
+# Check for TrainableModel
+try:
+from judgeval.common.trainer import TrainableModel
+
+if isinstance(client, TrainableModel):
+# TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
+if (
+hasattr(response, "model")
+and hasattr(response, "usage")
+and hasattr(response, "choices")
+):
+model_name = response.model
+prompt_tokens = response.usage.prompt_tokens if response.usage else 0
+completion_tokens = (
+response.usage.completion_tokens if response.usage else 0
+)
+message_content = response.choices[0].message.content
+
+# Use LiteLLM cost calculation with fireworks_ai prefix
+# LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
+fireworks_model_name = f"fireworks_ai/{model_name}"
+return message_content, _create_usage(
+fireworks_model_name,
+prompt_tokens,
+completion_tokens,
+cache_read_input_tokens,
+cache_creation_input_tokens,
+)
+except ImportError:
+pass # TrainableModel not available
+
 judgeval_logger.warning(f"Unsupported client type: {type(client)}")
 return None, None
 
@@ -2223,13 +2415,13 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
 """
 if class_name in class_identifiers:
 class_config = class_identifiers[class_name]
-attr = class_config["identifier"]
-
-if hasattr(instance, attr):
-instance_name = getattr(instance, attr)
-return instance_name
-else:
-raise Exception(
-f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
-)
-return None
+attr = class_config.get("identifier")
+if attr:
+if hasattr(instance, attr) and not callable(getattr(instance, attr)):
+instance_name = getattr(instance, attr)
+return instance_name
+else:
+raise Exception(
+f"Attribute {attr} does not exist for {class_name}. Check your agent() decorator."
+)
+return None
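
For readers updating from 0.5.0, the user-facing changes in the hunks above are: the EvaluationRun import moved from judgeval.evaluation_run to judgeval.data.evaluation_run, the Tracer gained a show_trace_urls flag (JUDGMENT_SHOW_TRACE_URLS), Tracer.identify() was renamed to Tracer.agent() (identify() now only warns and delegates), and new trace_to_message_history() / get_current_message_history() helpers were added. The sketch below is a hypothetical migration example based only on this diff; the Tracer import path and constructor arguments shown are assumptions for illustration, not part of the diff.

# Hypothetical migration sketch for judgeval 0.5.0 -> 0.7.0, based only on the hunks above.
from judgeval.common.tracer import Tracer  # assumed import path for Tracer
from judgeval.data.evaluation_run import EvaluationRun  # was: from judgeval.evaluation_run import EvaluationRun

# show_trace_urls is new in 0.7.0; project_name is assumed here for illustration.
judgment = Tracer(project_name="my_project", show_trace_urls=False)

# 0.5.0 style, still accepted but logs a deprecation warning:
#   @judgment.identify(identifier="name", track_state=True)
# 0.7.0 style, class decorator with the same keyword arguments:
@judgment.agent(identifier="name", track_state=True)
class ResearchAgent:
    def __init__(self, name: str):
        self.name = name  # "identifier" names this attribute on the instance

# Inside an active trace, the new helper can rebuild the conversation from recorded spans
# (it raises ValueError when no trace is current):
#   messages = judgment.get_current_message_history()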