braintrust 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/bt_json.py +178 -19
- braintrust/db_fields.py +1 -0
- braintrust/framework.py +13 -4
- braintrust/logger.py +30 -120
- braintrust/otel/__init__.py +24 -15
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +81 -0
- braintrust/test_logger.py +245 -107
- braintrust/test_otel.py +118 -26
- braintrust/test_util.py +51 -1
- braintrust/util.py +24 -3
- braintrust/version.py +2 -2
- braintrust/wrappers/google_genai/__init__.py +2 -15
- braintrust/wrappers/litellm.py +43 -0
- braintrust/wrappers/pydantic_ai.py +209 -95
- braintrust/wrappers/test_google_genai.py +62 -1
- braintrust/wrappers/test_litellm.py +73 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/METADATA +1 -1
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/RECORD +23 -22
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/WHEEL +0 -0
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/entry_points.txt +0 -0
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownParameterType=false
 # pyright: reportPrivateUsage=false
+import asyncio
 import time

 import pytest
@@ -378,30 +379,28 @@ async def test_agent_run_stream(memory_logger):
     assert chat_span["metadata"]["provider"] == "openai"
     _assert_metrics_are_valid(chat_span["metrics"], start, end)

-    # CRITICAL: Check that
-
-
-
-
+    # CRITICAL: Check that time_to_first_token is captured
+    assert "time_to_first_token" in agent_span["metrics"], "agent_run_stream span should have time_to_first_token metric"
+    ttft = agent_span["metrics"]["time_to_first_token"]
+    duration = agent_span["metrics"]["duration"]
+
+    # time_to_first_token should be reasonable: > 0 and < duration
+    assert ttft > 0, f"time_to_first_token should be > 0, got {ttft}"
+    assert ttft <= duration, f"time_to_first_token ({ttft}s) should be <= duration ({duration}s)"
+    assert ttft < 3.0, f"time_to_first_token should be < 3s for API call, got {ttft}s"

     # Debug: Print full span data
     print(f"\n=== AGENT SPAN ===")
     print(f"ID: {agent_span['id']}")
     print(f"span_id: {agent_span['span_id']}")
     print(f"metrics: {agent_span['metrics']}")
+    print(f"time_to_first_token: {ttft}s")
     print(f"\n=== CHAT SPAN ===")
     print(f"ID: {chat_span['id']}")
     print(f"span_id: {chat_span['span_id']}")
     print(f"span_parents: {chat_span['span_parents']}")
     print(f"metrics: {chat_span['metrics']}")

-    # Time to first token should be reasonable (< 3 seconds for API call initiation)
-    assert time_to_first_token < 3.0, f"Time to first token too large: {time_to_first_token}s - suggests start_time is being reused incorrectly"
-
-    # Both spans should have started during our test timeframe
-    assert agent_start >= start, "Agent span started before test"
-    assert chat_start >= start, "Chat span started before test"
-
     # Agent spans should have token metrics
     assert "prompt_tokens" in agent_span["metrics"]
     assert "completion_tokens" in agent_span["metrics"]
@@ -550,7 +549,7 @@ async def test_direct_model_request_with_settings(memory_logger, direct):
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_direct_model_request_stream(memory_logger, direct):
-    """Test direct API model_request_stream()."""
+    """Test direct API model_request_stream() - verifies time_to_first_token is captured."""
     assert not memory_logger.pop()

     messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 3")])]
@@ -578,6 +577,18 @@ async def test_direct_model_request_stream(memory_logger, direct):
     assert direct_span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(direct_span["metrics"], start, end)

+    # CRITICAL: Verify time_to_first_token is captured in direct streaming
+    assert "time_to_first_token" in direct_span["metrics"], "model_request_stream span should have time_to_first_token metric"
+    ttft = direct_span["metrics"]["time_to_first_token"]
+    duration = direct_span["metrics"]["duration"]
+
+    # time_to_first_token should be reasonable: > 0 and < duration
+    assert ttft > 0, f"time_to_first_token should be > 0, got {ttft}"
+    assert ttft <= duration, f"time_to_first_token ({ttft}s) should be <= duration ({duration}s)"
+    assert ttft < 3.0, f"time_to_first_token should be < 3s for API call, got {ttft}s"
+
+    print(f"✓ Direct stream time_to_first_token: {ttft}s (duration: {duration}s)")
+

 @pytest.mark.vcr
 @pytest.mark.asyncio
@@ -1131,7 +1142,7 @@ async def test_agent_with_custom_settings(memory_logger):

 @pytest.mark.vcr
 def test_agent_run_stream_sync(memory_logger):
-    """Test Agent.run_stream_sync() synchronous streaming method."""
+    """Test Agent.run_stream_sync() synchronous streaming method - verifies time_to_first_token."""
     assert not memory_logger.pop()

     agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
@@ -1164,6 +1175,18 @@ def test_agent_run_stream_sync(memory_logger):
     assert "Count from 1 to 3" in str(agent_span["input"])
     _assert_metrics_are_valid(agent_span["metrics"], start, end)

+    # CRITICAL: Verify time_to_first_token is captured in sync streaming
+    assert "time_to_first_token" in agent_span["metrics"], "agent_run_stream_sync span should have time_to_first_token metric"
+    ttft = agent_span["metrics"]["time_to_first_token"]
+    duration = agent_span["metrics"]["duration"]
+
+    # time_to_first_token should be reasonable: > 0 and < duration
+    assert ttft > 0, f"time_to_first_token should be > 0, got {ttft}"
+    assert ttft <= duration, f"time_to_first_token ({ttft}s) should be <= duration ({duration}s)"
+    assert ttft < 3.0, f"time_to_first_token should be < 3s for API call, got {ttft}s"
+
+    print(f"✓ Sync stream time_to_first_token: {ttft}s (duration: {duration}s)")
+
     # Check chat span is a descendant of agent_run_stream_sync
     span_by_id = {s["span_id"]: s for s in spans}

@@ -1227,7 +1250,7 @@ async def test_agent_run_stream_events(memory_logger):

 @pytest.mark.vcr
 def test_direct_model_request_stream_sync(memory_logger, direct):
-    """Test direct API model_request_stream_sync()."""
+    """Test direct API model_request_stream_sync() - verifies time_to_first_token."""
     assert not memory_logger.pop()

     messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 3")])]
@@ -1252,6 +1275,18 @@ def test_direct_model_request_stream_sync(memory_logger, direct):
     assert span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(span["metrics"], start, end)

+    # CRITICAL: Verify time_to_first_token is captured in sync direct streaming
+    assert "time_to_first_token" in span["metrics"], "model_request_stream_sync span should have time_to_first_token metric"
+    ttft = span["metrics"]["time_to_first_token"]
+    duration = span["metrics"]["duration"]
+
+    # time_to_first_token should be reasonable: > 0 and < duration
+    assert ttft > 0, f"time_to_first_token should be > 0, got {ttft}"
+    assert ttft <= duration, f"time_to_first_token ({ttft}s) should be <= duration ({duration}s)"
+    assert ttft < 3.0, f"time_to_first_token should be < 3s for API call, got {ttft}s"
+
+    print(f"✓ Direct sync stream time_to_first_token: {ttft}s (duration: {duration}s)")
+

 @pytest.mark.vcr
 @pytest.mark.asyncio
@@ -1341,19 +1376,168 @@ async def test_agent_stream_early_break(memory_logger):
     assert "start" in agent_span["metrics"]


+@pytest.mark.vcr
+@pytest.mark.asyncio
+async def test_stream_buffer_pattern_early_return(memory_logger, direct):
+    """Test the _stream_single/_buffer_stream pattern with early return.
+
+    This tests a common customer pattern where:
+    1. An async generator wraps a stream and yields chunks + final response
+    2. A consumer function returns early when it sees the final ModelResponse
+    3. The generator cleanup happens in a different async context
+
+    This pattern would trigger 'Token was created in a different Context' errors
+    before the task-tracking fix, because the consumer's early return causes
+    the generator to be cleaned up in a different task context.
+    """
+    from collections.abc import AsyncIterator
+
+    from pydantic_ai.messages import ModelResponse
+
+    assert not memory_logger.pop()
+
+    messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 5")])]
+
+    class LLMStreamResponse:
+        """Wrapper for streaming responses."""
+
+        def __init__(self, llm_response, is_final=False):
+            self.llm_response = llm_response
+            self.is_final = is_final
+
+    async def _stream_single() -> AsyncIterator[LLMStreamResponse]:
+        """Async generator that yields streaming chunks and final response."""
+        async with direct.model_request_stream(model=MODEL, messages=messages) as stream:
+            async for chunk in stream:
+                yield LLMStreamResponse(llm_response=chunk, is_final=False)
+
+            # Yield the final response after streaming completes
+            response = stream.get()
+            yield LLMStreamResponse(llm_response=response, is_final=True)
+
+    async def _buffer_stream() -> LLMStreamResponse:
+        """Consumer that returns early when it gets a ModelResponse.
+
+        This early return causes the generator to be cleaned up in a different
+        async context than where it was created, triggering the context issue.
+        """
+        async for event in _stream_single():
+            if isinstance(event.llm_response, ModelResponse):
+                # Early return - generator cleanup happens in different context
+                return event
+        raise RuntimeError("No ModelResponse received")
+
+    start = time.time()
+
+    # This should NOT raise ValueError about "different Context"
+    result = await _buffer_stream()
+    end = time.time()
+
+    # Verify we got the final response
+    assert isinstance(result.llm_response, ModelResponse)
+    assert result.is_final
+
+    # Check spans - should have created a span despite early generator cleanup
+    spans = memory_logger.pop()
+    assert len(spans) >= 1, "Should have at least one span even with early return"
+
+    span = spans[0]
+    assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert span["span_attributes"]["name"] == "model_request_stream"
+    assert "start" in span["metrics"]
+    assert span["metrics"]["start"] >= start
+    # "end" may not be present if span was terminated early, but if present it should be valid
+    if "end" in span["metrics"]:
+        assert span["metrics"]["end"] <= end
+
+
+@pytest.mark.vcr
+@pytest.mark.asyncio
+async def test_agent_stream_buffer_pattern_early_return(memory_logger):
+    """Test the _stream_single/_buffer_stream pattern with agent.run_stream().
+
+    This tests the same buffer/stream pattern but with the high-level Agent API
+    to ensure _AgentStreamWrapper also handles context cleanup correctly.
+
+    Pattern:
+    1. An async generator wraps agent.run_stream() and yields events + final result
+    2. A consumer returns early when it sees the final result
+    3. Generator cleanup happens in a different context
+
+    This verifies both _AgentStreamWrapper and _DirectStreamWrapper handle
+    task context changes correctly.
+    """
+    from collections.abc import AsyncIterator
+
+    assert not memory_logger.pop()
+
+    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
+
+    class StreamEvent:
+        """Wrapper for stream events."""
+
+        def __init__(self, data, is_final=False):
+            self.data = data
+            self.is_final = is_final
+
+    async def _agent_stream_wrapper() -> AsyncIterator[StreamEvent]:
+        """Async generator that wraps agent streaming."""
+        async with agent.run_stream("Count from 1 to 5") as result:
+            # Yield text chunks
+            async for text in result.stream_text(delta=True):
+                yield StreamEvent(data=text, is_final=False)
+
+            # Yield final result after streaming
+            # Note: We can't call result.output here as it's consumed during streaming,
+            # so we yield a marker for the final event
+            yield StreamEvent(data="FINAL", is_final=True)
+
+    async def _consume_until_final() -> StreamEvent:
+        """Consumer that returns early when it sees final event.
+
+        This early return causes generator cleanup in different context.
+        """
+        async for event in _agent_stream_wrapper():
+            if event.is_final:
+                # Early return - generator cleanup in different context
+                return event
+        raise RuntimeError("No final event received")
+
+    start = time.time()
+
+    # This should NOT raise ValueError about "different Context"
+    result = await _consume_until_final()
+    end = time.time()
+
+    # Verify we got the final event
+    assert result.is_final
+    assert result.data == "FINAL"
+
+    # Check spans - should have created spans despite early generator cleanup
+    spans = memory_logger.pop()
+    assert len(spans) >= 1, "Should have at least one span"
+
+    # Find agent_run_stream span
+    agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
+    assert agent_span is not None, "agent_run_stream span should exist"
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert "start" in agent_span["metrics"]
+
+
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_agent_with_binary_content(memory_logger):
     """Test that agents with binary content (images) work correctly.

-
-
+    Verifies that BinaryContent is properly converted to Braintrust attachments
+    in both the agent_run span (parent) and chat span (child).
     """
+    from braintrust.logger import Attachment
     from pydantic_ai.models.function import BinaryContent

     assert not memory_logger.pop()

-    # Use a small test image
+    # Use a small test image (1x1 PNG)
     image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82'

     agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
@@ -1370,20 +1554,184 @@ async def test_agent_with_binary_content(memory_logger):
     assert result.output
     assert isinstance(result.output, str)

-    # Check spans -
+    # Check spans - should have both agent_run and chat spans
     spans = memory_logger.pop()
-    assert len(spans) >=
+    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run + chat), got {len(spans)}"

-    # Find agent_run span
+    # Find agent_run span (parent)
     agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
     assert agent_span is not None, "agent_run span not found"

+    # Find chat span (child)
+    chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
+    assert chat_span is not None, "chat span not found"
+
     # Verify basic span structure
     assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(agent_span["metrics"], start, end)

-    #
+    # CRITICAL: Verify that BOTH spans properly serialize BinaryContent to attachments
+    def has_attachment_in_input(span_input):
+        """Check if span input contains a Braintrust Attachment object."""
+        if not span_input:
+            return False
+
+        def check_item(item):
+            """Recursively check an item for attachments."""
+            if isinstance(item, dict):
+                if item.get("type") == "binary" and isinstance(item.get("attachment"), Attachment):
+                    return True
+                # Check nested content field (for UserPromptPart-like structures)
+                if "content" in item:
+                    content = item["content"]
+                    if isinstance(content, list):
+                        for sub_item in content:
+                            if check_item(sub_item):
+                                return True
+            return False
+
+        # Check user_prompt (agent_run span)
+        if "user_prompt" in span_input:
+            user_prompt = span_input["user_prompt"]
+            if isinstance(user_prompt, list):
+                for item in user_prompt:
+                    if check_item(item):
+                        return True
+
+        # Check messages (chat span)
+        if "messages" in span_input:
+            messages = span_input["messages"]
+            if isinstance(messages, list):
+                for msg in messages:
+                    if isinstance(msg, dict) and "parts" in msg:
+                        parts = msg["parts"]
+                        if isinstance(parts, list):
+                            for part in parts:
+                                if check_item(part):
+                                    return True
+
+        return False
+
+    # Verify agent_run span has attachment
+    agent_has_attachment = has_attachment_in_input(agent_span.get("input", {}))
+    assert agent_has_attachment, (
+        "agent_run span should have BinaryContent converted to Braintrust Attachment. "
+        f"Input: {agent_span.get('input', {})}"
+    )
+
+    # Verify chat span has attachment (this is the key test for the bug)
+    chat_has_attachment = has_attachment_in_input(chat_span.get("input", {}))
+    assert chat_has_attachment, (
+        "chat span should have BinaryContent converted to Braintrust Attachment. "
+        "The child span should process attachments the same way as the parent. "
+        f"Input: {chat_span.get('input', {})}"
+    )
+
+
+@pytest.mark.vcr
+@pytest.mark.asyncio
+async def test_agent_with_document_input(memory_logger):
+    """Test that agents with document input (PDF) properly serialize attachments.
+
+    This specifically tests the scenario from test_document_input in the golden tests,
+    verifying that both agent_run and chat spans convert BinaryContent to Braintrust
+    attachments for document files like PDFs.
+    """
+    from braintrust.logger import Attachment
+    from pydantic_ai.models.function import BinaryContent
+
+    assert not memory_logger.pop()
+
+    # Create a minimal PDF (this is a valid but minimal PDF structure)
+    pdf_data = b'%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj 2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj 3 0 obj<</Type/Page/Parent 2 0 R/MediaBox[0 0 612 792]/Contents 4 0 R>>endobj 4 0 obj<</Length 44>>stream\nBT /F1 12 Tf 100 700 Td (Test Document) Tj ET\nendstream\nendobj\nxref\n0 5\n0000000000 65535 f\n0000000009 00000 n\n0000000058 00000 n\n0000000115 00000 n\n0000000214 00000 n\ntrailer<</Size 5/Root 1 0 R>>\nstartxref\n307\n%%EOF'
+
+    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=150))
+
+    start = time.time()
+    result = await agent.run(
+        [
+            BinaryContent(data=pdf_data, media_type="application/pdf"),
+            "What is in this document?",
+        ]
+    )
+    end = time.time()
+
+    assert result.output
+    assert isinstance(result.output, str)
+
+    # Check spans
+    spans = memory_logger.pop()
+    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run + chat), got {len(spans)}"
+
+    # Find spans
+    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
+    chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
+
+    assert agent_span is not None, "agent_run span not found"
+    assert chat_span is not None, "chat span not found"
+
+    # Helper to check for PDF attachment
+    def has_pdf_attachment(span_input):
+        """Check if span input contains a PDF Braintrust Attachment."""
+        if not span_input:
+            return False
+
+        def check_item(item):
+            """Recursively check an item for PDF attachments."""
+            if isinstance(item, dict):
+                if item.get("type") == "binary" and item.get("media_type") == "application/pdf":
+                    attachment = item.get("attachment")
+                    if isinstance(attachment, Attachment):
+                        if attachment._reference.get("content_type") == "application/pdf":
+                            return True
+                # Check nested content field (for UserPromptPart-like structures)
+                if "content" in item:
+                    content = item["content"]
+                    if isinstance(content, list):
+                        for sub_item in content:
+                            if check_item(sub_item):
+                                return True
+            return False
+
+        # Check user_prompt (agent_run span)
+        if "user_prompt" in span_input:
+            user_prompt = span_input["user_prompt"]
+            if isinstance(user_prompt, list):
+                for item in user_prompt:
+                    if check_item(item):
+                        return True
+
+        # Check messages (chat span)
+        if "messages" in span_input:
+            messages = span_input["messages"]
+            if isinstance(messages, list):
+                for msg in messages:
+                    if isinstance(msg, dict) and "parts" in msg:
+                        parts = msg["parts"]
+                        if isinstance(parts, list):
+                            for part in parts:
+                                if check_item(part):
+                                    return True
+
+        return False
+
+    # Verify agent_run span has PDF attachment
+    assert has_pdf_attachment(agent_span.get("input", {})), (
+        "agent_run span should have PDF BinaryContent converted to Braintrust Attachment"
+    )
+
+    # Verify chat span has PDF attachment (critical for document input)
+    assert has_pdf_attachment(chat_span.get("input", {})), (
+        "chat span should have PDF BinaryContent converted to Braintrust Attachment. "
+        "This ensures documents are properly traced in the low-level model call. "
+        f"Chat span input: {chat_span.get('input', {})}"
+    )
+
+    # Verify metrics
+    _assert_metrics_are_valid(agent_span["metrics"], start, end)
+    _assert_metrics_are_valid(chat_span["metrics"], start, end)
+

 @pytest.mark.vcr
 @pytest.mark.asyncio
@@ -1459,6 +1807,7 @@ async def test_agent_with_tool_execution(memory_logger):
     assert "toolsets" not in agent_span["metadata"], "toolsets should NOT be in metadata"


+@pytest.mark.vcr
 def test_tool_execution_creates_spans(memory_logger):
     """Test that executing tools with agents works and creates traced spans.

@@ -1786,3 +2135,451 @@ async def test_agent_run_stream_structured_output(memory_logger):
     assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run_stream"
     assert chat_span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(chat_span["metrics"], start, end)
+
+
+@pytest.mark.vcr
+@pytest.mark.asyncio
+async def test_model_class_span_names(memory_logger):
+    """Test that model class spans have proper names.
+
+    Verifies that the nested chat span from the model class wrapper has a
+    meaningful name (either the model name or class name), not a misleading
+    string like 'log'.
+
+    This test ensures that when model_name is None, we fall back to the
+    class name (e.g., 'OpenAIChatModel') rather than str(instance) which
+    could return unexpected values.
+    """
+    assert not memory_logger.pop()
+
+    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
+
+    start = time.time()
+    result = await agent.run("What is 2+2?")
+    end = time.time()
+
+    assert result.output
+
+    # Check spans
+    spans = memory_logger.pop()
+    assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"
+
+    # Find chat span (the nested model class span)
+    chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
+    assert chat_span is not None, "chat span not found"
+
+    span_name = chat_span["span_attributes"]["name"]
+
+    # Verify the span name is meaningful
+    # It should be either "chat <model_name>" or "chat <ClassName>"
+    # but NOT "chat log" or other misleading names
+    assert span_name.startswith("chat "), f"Chat span should start with 'chat ', got: {span_name}"
+
+    # Extract the model/class identifier part after "chat "
+    identifier = span_name[5:]  # Skip "chat "
+
+    # Should not be empty or misleading values
+    assert identifier, "Chat span should have a model name or class name after 'chat '"
+    assert identifier != "log", "Chat span should not be named 'log' - should use model name or class name"
+    assert len(identifier) > 2, f"Chat span identifier seems too short: {identifier}"
+
+    # Common valid patterns:
+    # - "chat gpt-4o-mini" (model name extracted)
+    # - "chat OpenAIChatModel" (class name fallback)
+    # - "chat gpt-4o" (model name)
+    valid_patterns = [
+        "gpt-" in identifier,  # OpenAI model names
+        "claude" in identifier.lower(),  # Anthropic models
+        "Model" in identifier,  # Class name fallback (e.g., OpenAIChatModel)
+        "-" in identifier,  # Model names typically have hyphens
+    ]
+
+    assert any(valid_patterns), (
+        f"Chat span name '{span_name}' doesn't match expected patterns. "
+        f"Should contain model name (e.g., 'gpt-4o-mini') or class name (e.g., 'OpenAIChatModel')"
+    )
+
+    # Verify span has proper structure
+    assert chat_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert chat_span["metadata"]["model"] == "gpt-4o-mini"
+    _assert_metrics_are_valid(chat_span["metrics"], start, end)
+
+
+def test_serialize_content_part_with_binary_content():
+    """Unit test to verify _serialize_content_part handles BinaryContent correctly.
+
+    This tests the direct serialization of BinaryContent objects and verifies
+    they are converted to Braintrust Attachment objects.
+    """
+    from braintrust.logger import Attachment
+    from braintrust.wrappers.pydantic_ai import _serialize_content_part
+    from pydantic_ai.models.function import BinaryContent
+
+    # Test 1: Direct BinaryContent serialization
+    binary = BinaryContent(data=b"test pdf data", media_type="application/pdf")
+    result = _serialize_content_part(binary)
+
+    assert result is not None, "Should serialize BinaryContent"
+    assert result["type"] == "binary", "Should have type 'binary'"
+    assert result["media_type"] == "application/pdf", "Should preserve media_type"
+    assert isinstance(result["attachment"], Attachment), "Should convert to Braintrust Attachment"
+
+    # Verify attachment has correct content_type
+    assert result["attachment"]._reference["content_type"] == "application/pdf"
+
+
+def test_serialize_content_part_with_user_prompt_part():
+    """Unit test to verify _serialize_content_part handles UserPromptPart with nested BinaryContent.
+
+    This is the critical test for the bug: when a UserPromptPart has a content list
+    containing BinaryContent, we need to recursively serialize the content items
+    so that BinaryContent is converted to Braintrust Attachment.
+    """
+    from braintrust.logger import Attachment
+    from braintrust.wrappers.pydantic_ai import _serialize_content_part
+    from pydantic_ai.messages import UserPromptPart
+    from pydantic_ai.models.function import BinaryContent
+
+    # Create a UserPromptPart with mixed content (BinaryContent + string)
+    pdf_data = b"%PDF-1.4 test document content"
+    binary = BinaryContent(data=pdf_data, media_type="application/pdf")
+    user_prompt_part = UserPromptPart(content=[binary, "What is in this document?"])
+
+    # Serialize the UserPromptPart
+    result = _serialize_content_part(user_prompt_part)
+
+    # Verify the result is a dict with serialized content
+    assert isinstance(result, dict), f"Should return dict, got {type(result)}"
+    assert "content" in result, f"Should have 'content' key. Keys: {result.keys()}"
+
+    content = result["content"]
+    assert isinstance(content, list), f"Content should be a list, got {type(content)}"
+    assert len(content) == 2, f"Should have 2 content items, got {len(content)}"
+
+    # CRITICAL: First item should be serialized BinaryContent with Attachment
+    binary_item = content[0]
+    assert isinstance(binary_item, dict), f"Binary item should be dict, got {type(binary_item)}"
+    assert binary_item.get("type") == "binary", (
+        f"Binary item should have type='binary'. Got: {binary_item}"
+    )
+    assert "attachment" in binary_item, (
+        f"Binary item should have 'attachment' key. Keys: {binary_item.keys()}"
+    )
+    assert isinstance(binary_item["attachment"], Attachment), (
+        f"Should be Braintrust Attachment, got {type(binary_item.get('attachment'))}"
+    )
+    assert binary_item["media_type"] == "application/pdf"
+
+    # Second item should be the string
+    assert content[1] == "What is in this document?"
+
+
+def test_serialize_messages_with_binary_content():
+    """Unit test to verify _serialize_messages handles ModelRequest with BinaryContent in parts.
+
+    This tests the full message serialization path that's used for the chat span,
+    ensuring that nested BinaryContent in UserPromptPart is properly converted.
+    """
+    from braintrust.logger import Attachment
+    from braintrust.wrappers.pydantic_ai import _serialize_messages
+    from pydantic_ai.messages import ModelRequest, UserPromptPart
+    from pydantic_ai.models.function import BinaryContent
+
+    # Create a ModelRequest with UserPromptPart containing BinaryContent
+    pdf_data = b"%PDF-1.4 test document content"
+    binary = BinaryContent(data=pdf_data, media_type="application/pdf")
+    user_prompt_part = UserPromptPart(content=[binary, "What is in this document?"])
+    model_request = ModelRequest(parts=[user_prompt_part])
+
+    # Serialize the messages
+    messages = [model_request]
+    result = _serialize_messages(messages)
+
+    # Verify structure
+    assert len(result) == 1, f"Should have 1 message, got {len(result)}"
+    msg = result[0]
+    assert "parts" in msg, f"Message should have 'parts'. Keys: {msg.keys()}"
+
+    parts = msg["parts"]
+    assert len(parts) == 1, f"Should have 1 part, got {len(parts)}"
+
+    part = parts[0]
+    assert isinstance(part, dict), f"Part should be dict, got {type(part)}"
+    assert "content" in part, f"Part should have 'content'. Keys: {part.keys()}"
+
+    content = part["content"]
+    assert isinstance(content, list), f"Content should be list, got {type(content)}"
+    assert len(content) == 2, f"Should have 2 content items, got {len(content)}"
+
+    # CRITICAL: First content item should be serialized BinaryContent with Attachment
+    binary_item = content[0]
+    assert isinstance(binary_item, dict), f"Binary item should be dict, got {type(binary_item)}"
+    assert binary_item.get("type") == "binary", (
+        f"Binary item should have type='binary'. Got: {binary_item}"
+    )
+    assert "attachment" in binary_item, (
+        f"Binary item should have 'attachment'. Keys: {binary_item.keys()}"
+    )
+    assert isinstance(binary_item["attachment"], Attachment), (
+        f"Should be Braintrust Attachment, got {type(binary_item.get('attachment'))}"
+    )
+    assert binary_item["media_type"] == "application/pdf"
+
+    # Second content item should be the string
+    assert content[1] == "What is in this document?"
+
+
+@pytest.mark.asyncio
+async def test_streaming_wrappers_capture_time_to_first_token():
+    """Unit test verifying all streaming wrappers capture time_to_first_token.
+
+    This test uses mocks to verify the internal wrapper logic without requiring
+    API calls. It ensures that _first_token_time is tracked correctly in:
+    - _AgentStreamWrapper (async agent streaming)
+    - _DirectStreamWrapper (async direct API streaming)
+    - _AgentStreamResultSyncProxy (sync agent streaming)
+    - _DirectStreamWrapperSync (sync direct API streaming)
+    """
+    from unittest.mock import AsyncMock, MagicMock, Mock
+
+    from braintrust.wrappers.pydantic_ai import (
+        _AgentStreamResultSyncProxy,
+        _AgentStreamWrapper,
+        _DirectStreamIteratorProxy,
+        _DirectStreamIteratorSyncProxy,
+        _DirectStreamWrapper,
+        _DirectStreamWrapperSync,
+        _StreamResultProxy,
+    )
+
+    # Test 1: _AgentStreamWrapper captures first token time
+    print("\n--- Testing _AgentStreamWrapper ---")
+
+    class MockStreamResult:
+        async def stream_text(self, delta=True):
+            for i in range(3):
+                await asyncio.sleep(0.001)
+                yield f"token{i} "
+
+        def usage(self):
+            usage_mock = Mock(input_tokens=50, output_tokens=20, total_tokens=70)
+            usage_mock.cache_read_tokens = None
+            usage_mock.cache_write_tokens = None
+            return usage_mock
+
+    mock_stream_result = MockStreamResult()
+    wrapper = _AgentStreamWrapper(
+        stream_cm=AsyncMock(),
+        span_name="test_stream",
+        input_data={"prompt": "test"},
+        metadata={"model": "gpt-4o"},
+    )
+
+    wrapper.span_cm = MagicMock()
+    wrapper.span_cm.__enter__ = MagicMock()
+    wrapper.start_time = time.time()
+    wrapper.stream_result = mock_stream_result
+
+    proxy = _StreamResultProxy(mock_stream_result, wrapper)
+
+    assert wrapper._first_token_time is None
+
+    chunk_count = 0
+    async for text in proxy.stream_text(delta=True):
+        chunk_count += 1
+        if chunk_count == 1:
+            assert wrapper._first_token_time is not None
+            assert wrapper._first_token_time > wrapper.start_time
+
+    assert chunk_count == 3
+    assert wrapper._first_token_time is not None
+    print("✓ _AgentStreamWrapper captures first token time")
+
+    # Test 2: _DirectStreamWrapper captures first token time
+    print("\n--- Testing _DirectStreamWrapper ---")
+
+    class MockStream:
+        def __init__(self):
+            self.chunks = []
+
+        async def __anext__(self):
+            if len(self.chunks) < 3:
+                await asyncio.sleep(0.001)
+                chunk = Mock(delta=Mock(content_delta=f"chunk{len(self.chunks)}"))
+                self.chunks.append(chunk)
+                return chunk
+            raise StopAsyncIteration
+
+        def __aiter__(self):
+            return self
+
+        def get(self):
+            usage_mock = Mock(input_tokens=50, output_tokens=20, total_tokens=70)
+            usage_mock.cache_read_tokens = None
+            usage_mock.cache_write_tokens = None
+            return Mock(usage=usage_mock)
+
+    mock_stream = MockStream()
+    direct_wrapper = _DirectStreamWrapper(
+        stream_cm=AsyncMock(),
+        span_name="test_direct_stream",
+        input_data={"messages": []},
+        metadata={"model": "gpt-4o"},
+    )
+
+    direct_wrapper.span_cm = MagicMock()
+    direct_wrapper.start_time = time.time()
+    direct_wrapper.stream = mock_stream
+
+    proxy = _DirectStreamIteratorProxy(mock_stream, direct_wrapper)
+
+    assert direct_wrapper._first_token_time is None
+
+    chunk_count = 0
+    async for chunk in proxy:
+        chunk_count += 1
+        if chunk_count == 1:
+            assert direct_wrapper._first_token_time is not None
+            assert direct_wrapper._first_token_time > direct_wrapper.start_time
+
+    assert chunk_count == 3
+    assert direct_wrapper._first_token_time is not None
+    print("✓ _DirectStreamWrapper captures first token time")
+
+    # Test 3: _AgentStreamResultSyncProxy captures first token time
+    print("\n--- Testing _AgentStreamResultSyncProxy ---")
+
+    class MockSyncStreamResult:
+        def stream_text(self, delta=True):
+            for i in range(3):
+                time.sleep(0.001)
+                yield f"token{i} "
+
+        def usage(self):
+            usage_mock = Mock(input_tokens=50, output_tokens=20, total_tokens=70)
+            usage_mock.cache_read_tokens = None
+            usage_mock.cache_write_tokens = None
+            return usage_mock
+
+    mock_sync_result = MockSyncStreamResult()
+    sync_proxy = _AgentStreamResultSyncProxy(
+        stream_result=mock_sync_result,
+        span=MagicMock(),
+        span_cm=MagicMock(),
+        start_time=time.time(),
+    )
+
+    assert sync_proxy._first_token_time is None
+
+    chunk_count = 0
+    for text in sync_proxy.stream_text(delta=True):
+        chunk_count += 1
+        if chunk_count == 1:
+            assert sync_proxy._first_token_time is not None
+
+    assert chunk_count == 3
+    assert sync_proxy._first_token_time is not None
+    print("✓ _AgentStreamResultSyncProxy captures first token time")
+
+    # Test 4: _DirectStreamWrapperSync captures first token time
+    print("\n--- Testing _DirectStreamWrapperSync ---")
+
+    class MockSyncStream:
+        def __init__(self):
+            self.chunks = []
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            if len(self.chunks) < 3:
+                time.sleep(0.001)
+                chunk = Mock(delta=Mock(content_delta=f"chunk{len(self.chunks)}"))
+                self.chunks.append(chunk)
+                return chunk
+            raise StopIteration
+
+        def get(self):
+            usage_mock = Mock(input_tokens=50, output_tokens=20, total_tokens=70)
+            usage_mock.cache_read_tokens = None
+            usage_mock.cache_write_tokens = None
+            return Mock(usage=usage_mock)
+
+    mock_sync_stream = MockSyncStream()
+    sync_wrapper = _DirectStreamWrapperSync(
+        stream_cm=MagicMock(),
+        span_name="test_sync_stream",
+        input_data={"messages": []},
+        metadata={"model": "gpt-4o"},
+    )
+
+    sync_wrapper.start_time = time.time()
+    sync_wrapper.stream = mock_sync_stream
+
+    sync_proxy = _DirectStreamIteratorSyncProxy(mock_sync_stream, sync_wrapper)
+
+    assert sync_wrapper._first_token_time is None
+
+    chunk_count = 0
+    for chunk in sync_proxy:
+        chunk_count += 1
+        if chunk_count == 1:
+            assert sync_wrapper._first_token_time is not None
+            assert sync_wrapper._first_token_time > sync_wrapper.start_time
+
+    assert chunk_count == 3
+    assert sync_wrapper._first_token_time is not None
+    print("✓ _DirectStreamWrapperSync captures first token time")
+
+    print("\n✅ All streaming wrapper unit tests passed!")
+
+
+@pytest.mark.asyncio
+async def test_attachment_preserved_in_model_settings(memory_logger):
+    """Test that attachments in model_settings are preserved through serialization."""
+    from braintrust.bt_json import bt_safe_deep_copy
+    from braintrust.logger import Attachment
+
+    attachment = Attachment(data=b"config data", filename="config.txt", content_type="text/plain")
+
+    # Simulate model_settings with attachment
+    settings = {"temperature": 0.7, "context_file": attachment}
+
+    # Test bt_safe_deep_copy preserves attachment
+    copied = bt_safe_deep_copy(settings)
+    assert copied["context_file"] is attachment
+    assert copied["temperature"] == 0.7
+
+
+@pytest.mark.asyncio
+async def test_attachment_in_message_part(memory_logger):
+    """Test that attachment in custom message part is preserved."""
+    from braintrust.bt_json import bt_safe_deep_copy
+    from braintrust.logger import Attachment
+
+    attachment = Attachment(data=b"message data", filename="msg.txt", content_type="text/plain")
+
+    # Simulate message part with attachment
+    message_part = {"type": "file", "content": attachment, "metadata": {"source": "upload"}}
+
+    copied = bt_safe_deep_copy(message_part)
+    assert copied["content"] is attachment
+    assert copied["type"] == "file"
+
+
+@pytest.mark.asyncio
+async def test_attachment_in_result_data(memory_logger):
+    """Test that attachment in custom result data is preserved."""
+    from braintrust.bt_json import bt_safe_deep_copy
+    from braintrust.logger import ExternalAttachment
+
+    ext_attachment = ExternalAttachment(
+        url="s3://bucket/result.pdf", filename="result.pdf", content_type="application/pdf"
+    )
+
+    # Simulate agent result with attachment
+    result_data = {"success": True, "output_file": ext_attachment, "metadata": {"processed": True}}
+
+    copied = bt_safe_deep_copy(result_data)
+    assert copied["output_file"] is ext_attachment
+    assert copied["success"] is True