braintrust 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +4 -0
- braintrust/_generated_types.py +596 -72
- braintrust/conftest.py +1 -0
- braintrust/functions/invoke.py +35 -2
- braintrust/generated_types.py +15 -1
- braintrust/oai.py +88 -6
- braintrust/version.py +2 -2
- braintrust/wrappers/pydantic_ai.py +1203 -0
- braintrust/wrappers/test_oai_attachments.py +322 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
- braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
- {braintrust-0.3.14.dist-info → braintrust-0.3.15.dist-info}/METADATA +1 -1
- {braintrust-0.3.14.dist-info → braintrust-0.3.15.dist-info}/RECORD +16 -13
- {braintrust-0.3.14.dist-info → braintrust-0.3.15.dist-info}/WHEEL +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.3.15.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.3.15.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1788 @@
|
|
|
1
|
+
# pyright: reportUntypedFunctionDecorator=false
|
|
2
|
+
# pyright: reportUnknownMemberType=false
|
|
3
|
+
# pyright: reportUnknownParameterType=false
|
|
4
|
+
# pyright: reportPrivateUsage=false
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from braintrust import logger, setup_pydantic_ai, traced
|
|
9
|
+
from braintrust.span_types import SpanTypeAttribute
|
|
10
|
+
from braintrust.test_helpers import init_test_logger
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
from pydantic_ai import Agent, ModelSettings
|
|
13
|
+
from pydantic_ai.messages import ModelRequest, UserPromptPart
|
|
14
|
+
|
|
15
|
+
# Braintrust project that receives every span logged by this module.
PROJECT_NAME = "test-pydantic-ai-integration"
# pydantic_ai model identifier ("provider:model"); a cheaper model is used for tests.
MODEL = "openai:gpt-4o-mini"
# Arithmetic prompt whose expected answer ("4") several tests assert on.
TEST_PROMPT = "What is 2+2? Answer with just the number."
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.fixture(scope="module", autouse=True)
def setup_wrapper():
    """Install the Braintrust pydantic_ai instrumentation once for this module.

    Declared ``autouse`` with module scope, so every test here runs with the
    wrapper already active without having to request this fixture.
    """
    setup_pydantic_ai(project_name=PROJECT_NAME)
    yield None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture(scope="module")
def direct():
    """Hand tests the ``pydantic_ai.direct`` module.

    The import is deferred to fixture time rather than done at module top
    level. The autouse, module-scoped ``setup_wrapper`` fixture has already
    run by then.
    """
    from pydantic_ai import direct as direct_api

    return direct_api
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.fixture(scope="module")
def vcr_config():
    """VCR settings shared by the ``@pytest.mark.vcr`` tests in this module.

    Credential-bearing request headers are listed under ``filter_headers`` so
    they are filtered out of recorded cassettes.
    """
    redacted_headers = ["authorization", "openai-organization", "x-api-key"]
    return {"filter_headers": redacted_headers}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@pytest.fixture
def memory_logger():
    """Yield an in-memory background logger for inspecting logged spans.

    The test logger is initialised for PROJECT_NAME first. Tests call
    ``.pop()`` on the yielded object to collect the spans logged so far.
    """
    init_test_logger(PROJECT_NAME)
    with logger._internal_with_memory_background_logger() as captured:
        yield captured
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _assert_metrics_are_valid(metrics, start, end):
|
|
53
|
+
"""Assert that metrics contain expected fields and values."""
|
|
54
|
+
assert "start" in metrics
|
|
55
|
+
assert "end" in metrics
|
|
56
|
+
assert "duration" in metrics
|
|
57
|
+
assert start <= metrics["start"] <= metrics["end"] <= end
|
|
58
|
+
assert metrics["duration"] > 0
|
|
59
|
+
|
|
60
|
+
# Token metrics (if present)
|
|
61
|
+
if "tokens" in metrics:
|
|
62
|
+
assert metrics["tokens"] > 0
|
|
63
|
+
if "prompt_tokens" in metrics:
|
|
64
|
+
assert metrics["prompt_tokens"] > 0
|
|
65
|
+
if "completion_tokens" in metrics:
|
|
66
|
+
assert metrics["completion_tokens"] > 0
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_run_async(memory_logger):
    """Agent.run() should log an agent_run span with exactly one nested chat span."""
    assert not memory_logger.pop()

    math_agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))

    t_before = time.time()
    run_result = await math_agent.run(TEST_PROMPT)
    t_after = time.time()

    # The model must actually have answered the arithmetic question.
    assert run_result.output
    assert "4" in str(run_result.output)

    # Exactly two spans are expected: the agent_run parent and one chat child.
    logged = memory_logger.pop()
    assert len(logged) == 2, f"Expected 2 spans (agent_run + chat), got {len(logged)}"

    def first_named(fragment):
        return next((s for s in logged if fragment in s["span_attributes"]["name"]), None)

    agent_span = first_named("agent_run")
    chat_span = first_named("chat")
    assert agent_span is not None, "agent_run span not found"
    assert chat_span is not None, "chat span not found"

    # Parent span: LLM-typed, tagged with model/provider, carrying prompt and answer.
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    agent_meta = agent_span["metadata"]
    assert agent_meta["model"] == "gpt-4o-mini"
    assert agent_meta["provider"] == "openai"
    assert TEST_PROMPT in str(agent_span["input"])
    assert "4" in str(agent_span["output"])
    _assert_metrics_are_valid(agent_span["metrics"], t_before, t_after)

    # Child span is linked through the parent's span_id (not its row "id").
    assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run"
    assert chat_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert "gpt-4o-mini" in chat_span["span_attributes"]["name"]
    chat_meta = chat_span["metadata"]
    assert chat_meta["model"] == "gpt-4o-mini"
    assert chat_meta["provider"] == "openai"
    _assert_metrics_are_valid(chat_span["metrics"], t_before, t_after)

    # The agent span must carry positive token counts.
    agent_metrics = agent_span["metrics"]
    token_keys = ("prompt_tokens", "completion_tokens")
    for token_key in token_keys:
        assert token_key in agent_metrics
    for token_key in token_keys:
        assert agent_metrics[token_key] > 0
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@pytest.mark.vcr
def test_agent_run_sync(memory_logger):
    """Test Agent.run_sync() synchronous method.

    Expects an ``agent_run_sync`` span with a chat span somewhere beneath it.
    The chat span may sit under intermediate spans, so nesting is checked by
    walking ancestry rather than by comparing the direct parent only.
    """
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))

    start = time.time()
    result = agent.run_sync(TEST_PROMPT)
    end = time.time()

    # Verify the result
    assert result.output
    assert "4" in str(result.output)

    # Check spans - should have parent agent_run_sync + nested spans
    spans = memory_logger.pop()
    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run_sync + chat), got {len(spans)}"

    # Find agent_run_sync and chat spans
    agent_sync_span = next((s for s in spans if "agent_run_sync" in s["span_attributes"]["name"]), None)
    chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)

    assert agent_sync_span is not None, "agent_run_sync span not found"
    assert chat_span is not None, "chat span not found"

    # Check agent span
    assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert agent_sync_span["metadata"]["model"] == "gpt-4o-mini"
    assert agent_sync_span["metadata"]["provider"] == "openai"
    assert TEST_PROMPT in str(agent_sync_span["input"])
    assert "4" in str(agent_sync_span["output"])
    _assert_metrics_are_valid(agent_sync_span["metrics"], start, end)

    # Check chat span is a descendant of agent_run_sync span
    span_by_id = {s["span_id"]: s for s in spans}

    def is_descendant(child_span, ancestor_id):
        """Return True if ancestor_id appears anywhere in child_span's ancestry.

        Walks ``span_parents`` transitively through ``span_by_id``. The walk is
        iterative with a visited set, so deep chains or malformed cycles cannot
        recurse forever. The previous recursive version re-read the original
        chat span's parents at every level, which looped until RecursionError
        for any ancestor deeper than a grandparent. Parent ids that are not
        among the popped spans are treated as dead ends.
        """
        pending = list(child_span.get("span_parents") or [])
        visited = set()
        while pending:
            parent_id = pending.pop()
            if parent_id == ancestor_id:
                return True
            if parent_id in visited:
                continue
            visited.add(parent_id)
            parent_span = span_by_id.get(parent_id)
            if parent_span is not None:
                pending.extend(parent_span.get("span_parents") or [])
        return False

    assert is_descendant(chat_span, agent_sync_span["span_id"]), "chat span should be nested under agent_run_sync"
    assert chat_span["metadata"]["model"] == "gpt-4o-mini"
    assert chat_span["metadata"]["provider"] == "openai"
    _assert_metrics_are_valid(chat_span["metrics"], start, end)

    # Agent spans should have token metrics
    assert "prompt_tokens" in agent_sync_span["metrics"]
    assert "completion_tokens" in agent_sync_span["metrics"]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_multiple_identical_sequential_streams(memory_logger):
    """Test multiple identical sequential streaming calls to ensure offsets don't accumulate.

    This test makes 3 identical streaming calls in sequence. If timing is
    captured correctly, each chat span's offset relative to its parent agent
    span should be roughly the same (typically < 100ms). If offsets accumulate
    incorrectly, the second and third chat spans would have much larger
    offsets than the first.

    Each stream must also yield text; otherwise the timing checks would pass
    vacuously on empty runs.
    """
    assert not memory_logger.pop()

    num_streams = 3
    # Filled by the traced helper so the text can be checked after it returns.
    streamed_texts = []

    @traced
    async def run_multiple_identical_streams():
        # Make identical streaming calls, one after another.
        for i in range(num_streams):
            agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
            async with agent.run_stream("Count from 1 to 3.") as result:
                full_text = ""
                async for text in result.stream_text(delta=True):
                    full_text += text
            streamed_texts.append(full_text)
            print(f"Completed stream {i+1}")

    await run_multiple_identical_streams()

    # Every stream must have produced some text.
    assert len(streamed_texts) == num_streams
    for i, text in enumerate(streamed_texts):
        assert text, f"Stream {i+1} produced no text"

    # Check spans
    spans = memory_logger.pop()

    # Find agent and chat spans
    agent_spans = [s for s in spans if "agent_run" in s["span_attributes"]["name"]]
    chat_spans = [s for s in spans if "chat" in s["span_attributes"]["name"]]

    assert len(agent_spans) >= num_streams, f"Expected at least {num_streams} agent spans, got {len(agent_spans)}"
    assert len(chat_spans) >= num_streams, f"Expected at least {num_streams} chat spans, got {len(chat_spans)}"

    # Sort by creation time so the i-th agent span pairs with the i-th chat span.
    agent_spans.sort(key=lambda s: s["created"])
    chat_spans.sort(key=lambda s: s["created"])

    # Calculate time-to-first-token for each pair
    time_to_first_tokens = []
    for i in range(num_streams):
        agent_start = agent_spans[i]["metrics"]["start"]
        chat_start = chat_spans[i]["metrics"]["start"]
        ttft = chat_start - agent_start
        time_to_first_tokens.append(ttft)

        print(f"\n=== STREAM {i+1} ===")
        print(f"Agent span start: {agent_start}")
        print(f"Chat span start: {chat_start}")
        print(f"Time to first token: {ttft}s")
        print(f"Agent span ID: {agent_spans[i]['span_id']}")
        print(f"Chat span parents: {chat_spans[i]['span_parents']}")

    # CRITICAL: All time-to-first-token values should be similar (within 0.5s of each other).
    # If they were accumulating, later streams would show much larger values.
    min_ttft = min(time_to_first_tokens)
    max_ttft = max(time_to_first_tokens)
    ttft_spread = max_ttft - min_ttft

    print("\n=== TIME-TO-FIRST-TOKEN ANALYSIS ===")
    for i, ttft in enumerate(time_to_first_tokens):
        print(f"TTFT {i+1}: {ttft:.4f}s")
    print(f"Min: {min_ttft:.4f}s, Max: {max_ttft:.4f}s, Spread: {ttft_spread:.4f}s")

    # All should be small (< 3s)
    for i, ttft in enumerate(time_to_first_tokens):
        assert ttft < 3.0, f"Stream {i+1} time to first token too large: {ttft}s"

    # Spread should be small (< 0.5s) - this catches the accumulation bug
    assert ttft_spread < 0.5, f"Time-to-first-token spread too large: {ttft_spread}s - suggests timing is accumulating from previous calls"
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_multiple_sequential_streams(memory_logger):
    """Test multiple sequential streaming calls to ensure offsets don't accumulate.

    Two separate agents stream back-to-back under one ``@traced`` parent. The
    test checks that:
    - each chat span starts soon after its own agent span;
    - the second agent does not start before the first one ends;
    - both streams yield text;
    - every agent/chat span falls inside the measured wall-clock window.
    """
    assert not memory_logger.pop()

    # Filled by the traced helper so the text can be checked after it returns.
    streamed_texts = []

    @traced
    async def run_multiple_streams():
        agent1 = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
        agent2 = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))

        start = time.time()

        # First stream
        async with agent1.run_stream("Count from 1 to 3.") as result1:
            full_text1 = ""
            async for text in result1.stream_text(delta=True):
                full_text1 += text
        streamed_texts.append(full_text1)

        # Second stream
        async with agent2.run_stream("Count from 1 to 3.") as result2:
            full_text2 = ""
            async for text in result2.stream_text(delta=True):
                full_text2 += text
        streamed_texts.append(full_text2)

        return start

    start = await run_multiple_streams()
    end = time.time()

    # Both streams must have produced text.
    assert len(streamed_texts) == 2
    assert streamed_texts[0], "First stream produced no text"
    assert streamed_texts[1], "Second stream produced no text"

    # Check spans
    spans = memory_logger.pop()

    # Should have: 1 parent (run_multiple_streams) + 2 agent_run_stream spans + 2 chat spans = 5 total
    assert len(spans) >= 5, f"Expected at least 5 spans (1 parent + 2 agent_run_stream + 2 chat), got {len(spans)}"

    # Find agent and chat spans
    agent_spans = [s for s in spans if "agent_run" in s["span_attributes"]["name"]]
    chat_spans = [s for s in spans if "chat" in s["span_attributes"]["name"]]

    assert len(agent_spans) >= 2, f"Expected at least 2 agent spans, got {len(agent_spans)}"
    assert len(chat_spans) >= 2, f"Expected at least 2 chat spans, got {len(chat_spans)}"

    # Sort by creation time
    agent_spans.sort(key=lambda s: s["created"])
    chat_spans.sort(key=lambda s: s["created"])

    agent1_span, agent2_span = agent_spans[0], agent_spans[1]
    chat1_span, chat2_span = chat_spans[0], chat_spans[1]

    # Check timing for first pair
    agent1_start = agent1_span["metrics"]["start"]
    chat1_start = chat1_span["metrics"]["start"]
    time_to_first_token_1 = chat1_start - agent1_start

    # Check timing for second pair
    agent2_start = agent2_span["metrics"]["start"]
    chat2_start = chat2_span["metrics"]["start"]
    time_to_first_token_2 = chat2_start - agent2_start

    print("\n=== FIRST STREAM ===")
    print(f"Agent1 start: {agent1_start}")
    print(f"Chat1 start: {chat1_start}")
    print(f"Time to first token 1: {time_to_first_token_1}s")

    print("\n=== SECOND STREAM ===")
    print(f"Agent2 start: {agent2_start}")
    print(f"Chat2 start: {chat2_start}")
    print(f"Time to first token 2: {time_to_first_token_2}s")

    print("\n=== RELATIVE TIMING ===")
    print(f"Agent2 start - Agent1 start: {agent2_start - agent1_start}s")
    print(f"Chat2 start - Chat1 start: {chat2_start - chat1_start}s")

    # CRITICAL: Both time-to-first-token values should be small and similar
    assert time_to_first_token_1 < 3.0, f"First time to first token too large: {time_to_first_token_1}s"
    assert time_to_first_token_2 < 3.0, f"Second time to first token too large: {time_to_first_token_2}s - suggests start_time is being reused from first call"

    # Agent2 should start AFTER agent1 finishes (or near the end)
    agent1_end = agent1_span["metrics"]["end"]
    assert agent2_start >= agent1_end - 0.1, f"Agent2 started too early: {agent2_start} vs Agent1 end: {agent1_end}"

    # Every agent/chat span must lie inside the window measured around the run.
    for span in (agent1_span, chat1_span, agent2_span, chat2_span):
        _assert_metrics_are_valid(span["metrics"], start, end)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_run_stream(memory_logger):
    """Agent.run_stream() should log an agent_run_stream span wrapping one chat span.

    It also guards against stale start times: the chat span must begin within
    a few seconds of its parent agent span.
    """
    assert not memory_logger.pop()

    counting_agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))

    t_before = time.time()
    chunks = []
    async with counting_agent.run_stream("Count from 1 to 5") as stream_result:
        async for delta in stream_result.stream_text(delta=True):
            chunks.append(delta)
    t_after = time.time()
    streamed = "".join(chunks)

    # Streaming must have produced text containing at least one expected digit.
    assert streamed
    assert any(str(n) in streamed for n in range(1, 6))

    # Exactly two spans are expected: agent_run_stream and its chat child.
    logged = memory_logger.pop()
    assert len(logged) == 2, f"Expected 2 spans (agent_run_stream + chat), got {len(logged)}"

    agent_span = next((s for s in logged if "agent_run_stream" in s["span_attributes"]["name"]), None)
    chat_span = next((s for s in logged if "chat" in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run_stream span not found"
    assert chat_span is not None, "chat span not found"

    # Parent span checks.
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert agent_span["metadata"]["model"] == "gpt-4o-mini"
    assert "Count from 1 to 5" in str(agent_span["input"])
    _assert_metrics_are_valid(agent_span["metrics"], t_before, t_after)

    # Child span checks: directly parented by the agent span.
    assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run_stream"
    assert chat_span["metadata"]["model"] == "gpt-4o-mini"
    assert chat_span["metadata"]["provider"] == "openai"
    _assert_metrics_are_valid(chat_span["metrics"], t_before, t_after)

    # Offset between parent start and child start approximates time-to-first-token.
    agent_start = agent_span["metrics"]["start"]
    chat_start = chat_span["metrics"]["start"]
    time_to_first_token = chat_start - agent_start

    # Debug dump of both spans (visible with `pytest -s` or on failure).
    print("\n=== AGENT SPAN ===")
    for label, key in (("ID", "id"), ("span_id", "span_id"), ("metrics", "metrics")):
        print(f"{label}: {agent_span[key]}")
    print("\n=== CHAT SPAN ===")
    for label, key in (
        ("ID", "id"),
        ("span_id", "span_id"),
        ("span_parents", "span_parents"),
        ("metrics", "metrics"),
    ):
        print(f"{label}: {chat_span[key]}")

    assert time_to_first_token < 3.0, f"Time to first token too large: {time_to_first_token}s - suggests start_time is being reused incorrectly"

    # Both spans must have begun inside this test's time window.
    assert agent_start >= t_before, "Agent span started before test"
    assert chat_start >= t_before, "Chat span started before test"

    # Token usage must be recorded on the agent span.
    assert "prompt_tokens" in agent_span["metrics"]
    assert "completion_tokens" in agent_span["metrics"]
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_tools(memory_logger):
    """Test Agent with tool calls.

    Registers a plain tool and checks that the run still produces an
    ``agent_run`` span whose input mentions the request and whose timing
    metrics fall inside the measured window.
    """
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=200))

    @agent.tool_plain
    def get_weather(city: str) -> str:
        """Get weather for a city.

        Args:
            city: The city name
        """
        return f"It's sunny in {city}"

    start = time.time()
    result = await agent.run("What's the weather in Paris?")
    end = time.time()

    # Verify tool was used: the answer should mention the city or the tool's reply.
    assert result.output
    output_text = str(result.output)
    assert "Paris" in output_text or "sunny" in output_text

    # Check spans
    spans = memory_logger.pop()
    assert len(spans) >= 1  # At least the agent span, possibly more

    # Use a default so a missing span fails with a clear assertion. A bare
    # next() would raise StopIteration, which PEP 479 turns into an opaque
    # RuntimeError inside a coroutine.
    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run span not found"

    agent_input = str(agent_span["input"]).lower()
    assert "weather" in agent_input or "paris" in agent_input
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_model_request(memory_logger, direct):
    """direct.model_request() should log an LLM-typed ``model_request`` span."""
    assert not memory_logger.pop()

    request_messages = [ModelRequest(parts=[UserPromptPart(content=TEST_PROMPT)])]

    t_before = time.time()
    response = await direct.model_request(model=MODEL, messages=request_messages)
    t_after = time.time()

    # The model must have answered the arithmetic question.
    assert response.parts
    assert "4" in str(response.parts[0].content)

    logged = memory_logger.pop()
    # One or two spans, depending on whether the underlying model is wrapped too.
    assert len(logged) >= 1

    matches = [s for s in logged if s["span_attributes"]["name"] == "model_request"]
    direct_span = matches[0] if matches else None
    assert direct_span is not None

    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    span_meta = direct_span["metadata"]
    assert span_meta["model"] == "gpt-4o-mini"
    assert span_meta["provider"] == "openai"
    assert TEST_PROMPT in str(direct_span["input"])
    assert "4" in str(direct_span["output"])
    _assert_metrics_are_valid(direct_span["metrics"], t_before, t_after)
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
@pytest.mark.vcr
def test_direct_model_request_sync(memory_logger, direct):
    """direct.model_request_sync() should log an LLM-typed ``model_request_sync`` span."""
    assert not memory_logger.pop()

    request_messages = [ModelRequest(parts=[UserPromptPart(content=TEST_PROMPT)])]

    t_before = time.time()
    response = direct.model_request_sync(model=MODEL, messages=request_messages)
    t_after = time.time()

    # The model must have answered the arithmetic question.
    assert response.parts
    assert "4" in str(response.parts[0].content)

    # The sync path goes through several wrapping layers, so 2-3 spans are expected.
    logged = memory_logger.pop()
    assert len(logged) >= 2

    sync_span = next(
        (s for s in logged if s["span_attributes"]["name"] == "model_request_sync"),
        None,
    )
    assert sync_span is not None, "model_request_sync span not found"
    assert sync_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert sync_span["metadata"]["model"] == "gpt-4o-mini"
    assert TEST_PROMPT in str(sync_span["input"])
    _assert_metrics_are_valid(sync_span["metrics"], t_before, t_after)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_model_request_with_settings(memory_logger, direct):
    """Custom model_settings passed to model_request() are logged as input, not metadata."""
    assert not memory_logger.pop()

    hello_messages = [ModelRequest(parts=[UserPromptPart(content="Say hello")])]
    requested_settings = ModelSettings(max_tokens=50, temperature=0.7)

    t_before = time.time()
    response = await direct.model_request(
        model=MODEL,
        messages=hello_messages,
        model_settings=requested_settings,
    )
    t_after = time.time()

    assert response.parts

    logged = memory_logger.pop()
    # One or two spans, depending on whether the underlying model is wrapped too.
    assert len(logged) >= 1

    direct_span = next(
        (s for s in logged if s["span_attributes"]["name"] == "model_request"),
        None,
    )
    assert direct_span is not None
    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM

    # The settings belong in the logged input...
    span_input = direct_span["input"]
    assert "model_settings" in span_input, "model_settings should be in input"
    logged_settings = span_input["model_settings"]
    assert logged_settings["max_tokens"] == 50
    assert logged_settings["temperature"] == 0.7

    # ...and must not be duplicated into metadata, which still names model/provider.
    span_meta = direct_span["metadata"]
    assert "model_settings" not in span_meta, "model_settings should NOT be in metadata"
    assert span_meta["model"] == "gpt-4o-mini"
    assert span_meta["provider"] == "openai"

    _assert_metrics_are_valid(direct_span["metrics"], t_before, t_after)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_model_request_stream(memory_logger, direct):
    """direct.model_request_stream() should yield events and log a ``model_request_stream`` span."""
    assert not memory_logger.pop()

    count_messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 3")])]

    t_before = time.time()
    async with direct.model_request_stream(model=MODEL, messages=count_messages) as stream:
        # Drain the whole stream; only the number of events matters here.
        received = [event async for event in stream]
    t_after = time.time()

    assert len(received) > 0

    logged = memory_logger.pop()
    # One or two spans, depending on whether the underlying model is wrapped too.
    assert len(logged) >= 1

    stream_span = next(
        (s for s in logged if s["span_attributes"]["name"] == "model_request_stream"),
        None,
    )
    assert stream_span is not None

    assert stream_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert stream_span["metadata"]["model"] == "gpt-4o-mini"
    _assert_metrics_are_valid(stream_span["metrics"], t_before, t_after)
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_model_request_stream_complete_output(memory_logger, direct):
    """Direct streaming must expose all text, including the first chunk carried by PartStartEvent.

    Text is rebuilt from PartStartEvent ``part.content`` (used only before the
    first delta arrives) plus every PartDeltaEvent ``delta.content_delta``.
    """
    assert not memory_logger.pop()

    prompt_messages = [ModelRequest(parts=[UserPromptPart(content="Say exactly: 1, 2, 3")])]

    pieces = []
    saw_delta = False
    async with direct.model_request_stream(model=MODEL, messages=prompt_messages) as stream:
        async for event in stream:
            starts_part = hasattr(event, "part") and hasattr(event.part, "content")
            if starts_part and not saw_delta:
                # PartStartEvent: its part already holds the initial text.
                pieces.append(str(event.part.content))
            elif hasattr(event, "delta") and event.delta:
                saw_delta = True
                # PartDeltaEvent: incremental text lives in content_delta.
                delta_text = getattr(event.delta, "content_delta", None)
                if delta_text:
                    pieces.append(delta_text)
    collected_text = "".join(pieces)

    # The leading "1" only survives if the PartStartEvent text was captured.
    for digit in ("1", "2", "3"):
        assert digit in collected_text

    # At least one span must have been logged.
    logged = memory_logger.pop()
    assert len(logged) >= 1
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_api_streaming_call_3(memory_logger, direct):
    """Direct API streaming (call 3): the collected text should contain every digit 1-5.

    Text is rebuilt from PartStartEvent ``part.content`` (which carries the
    initial text) plus PartDeltaEvent ``delta.content_delta``.
    """
    assert not memory_logger.pop()

    IDENTICAL_PROMPT = "Count from 1 to 5."
    messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]

    pieces = []
    async with direct.model_request_stream(
        model="openai:gpt-4o",
        messages=messages,
        model_settings=ModelSettings(max_tokens=100),
    ) as stream:
        async for chunk in stream:
            if hasattr(chunk, "part") and hasattr(chunk.part, "content"):
                pieces.append(str(chunk.part.content))
            elif hasattr(chunk, "delta") and chunk.delta:
                delta_text = getattr(chunk.delta, "content_delta", None)
                if delta_text:
                    pieces.append(delta_text)
    collected_text = "".join(pieces)

    assert "1" in collected_text, f"Expected '1' in output but got: {collected_text}"
    for digit in ("2", "3", "4", "5"):
        assert digit in collected_text
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_api_streaming_call_4(memory_logger, direct):
    """Direct API streaming (call 4): same request as call 3; the leading '1' must be captured.

    Text is rebuilt from PartStartEvent ``part.content`` (which carries the
    initial text) plus PartDeltaEvent ``delta.content_delta``.
    """
    assert not memory_logger.pop()

    IDENTICAL_PROMPT = "Count from 1 to 5."
    messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]

    pieces = []
    async with direct.model_request_stream(
        model="openai:gpt-4o",
        messages=messages,
        model_settings=ModelSettings(max_tokens=100),
    ) as stream:
        async for chunk in stream:
            if hasattr(chunk, "part") and hasattr(chunk.part, "content"):
                pieces.append(str(chunk.part.content))
            elif hasattr(chunk, "delta") and chunk.delta:
                delta_text = getattr(chunk.delta, "content_delta", None)
                if delta_text:
                    pieces.append(delta_text)
    collected_text = "".join(pieces)

    assert "1" in collected_text, f"Expected '1' in output but got: {collected_text}"
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_api_streaming_early_break_call_5(memory_logger, direct):
    """Direct API streaming with an early break (call 5).

    Only the first ``max_chunks`` stream events are consumed. Text from the
    initial part-start event (or the first delta) must still have been captured.
    """
    assert not memory_logger.pop()

    IDENTICAL_PROMPT = "Count from 1 to 5."
    messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]
    max_chunks = 3

    collected_text = ""
    async with direct.model_request_stream(
        model="openai:gpt-4o",
        messages=messages,
        model_settings=ModelSettings(max_tokens=100),
    ) as stream:
        chunks_seen = 0
        async for chunk in stream:
            if hasattr(chunk, "part") and hasattr(chunk.part, "content"):
                # Part-start events carry the first slice of text.
                collected_text += str(chunk.part.content)
            elif hasattr(chunk, "delta") and chunk.delta:
                if hasattr(chunk.delta, "content_delta") and chunk.delta.content_delta:
                    collected_text += chunk.delta.content_delta

            chunks_seen += 1
            if chunks_seen >= max_chunks:
                break

    print(f"Collected text: '{collected_text}'")
    # One check is enough: the original second assertion tested the same
    # condition and could not tell which event type supplied the text.
    assert len(collected_text) > 0, "Expected some text even with early break but got empty string"
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_api_streaming_no_duplication(memory_logger, direct):
    """Test that direct API streaming doesn't duplicate output and captures all text in span.

    Text is accumulated from part-start events (``chunk.part.content``) and delta
    events (``chunk.delta.content_delta``). The ``model_request_stream`` span's
    output is then inspected for response parts when they are present.
    """
    assert not memory_logger.pop()

    collected_text = ""
    chunk_count = 0

    # Use direct API streaming
    messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 5, separated by commas.")])]
    async with direct.model_request_stream(
        messages=messages,
        model_settings=ModelSettings(max_tokens=100),
        model="openai:gpt-4o",
    ) as response:
        async for chunk in response:
            chunk_count += 1
            # Extract text from chunk
            text = None
            if hasattr(chunk, "part") and hasattr(chunk.part, "content"):
                text = str(chunk.part.content)
            elif hasattr(chunk, "delta") and chunk.delta:
                if hasattr(chunk.delta, "content_delta") and chunk.delta.content_delta:
                    text = chunk.delta.content_delta

            if text:
                collected_text += text

    print(f"Collected text from stream: '{collected_text}'")
    print(f"Total chunks: {chunk_count}")

    # Verify we collected complete text
    assert len(collected_text) > 0, "Should have collected text from stream"
    assert "1" in collected_text, "Should have '1' in output"

    # Check span captured the full output
    spans = memory_logger.pop()
    assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"

    # Find the model_request_stream span
    stream_span = next((s for s in spans if "model_request_stream" in s["span_attributes"]["name"]), None)
    assert stream_span is not None, "model_request_stream span not found"

    # `.get("output", {})` still returns None when the key is present with a None
    # value, which made the membership tests below raise TypeError instead of
    # skipping. `or {}` normalizes that case.
    span_output = stream_span.get("output") or {}
    print(f"Span output: {span_output}")

    # The span should capture the full response.
    # NOTE(review): parts are looked up under output["response"]["parts"] here;
    # if the logged output uses a different shape, this check is skipped silently.
    response_output = span_output.get("response") if isinstance(span_output, dict) else None
    if isinstance(response_output, dict) and "parts" in response_output:
        parts = response_output["parts"]
        span_text = "".join(str(p.get("content", "")) for p in parts if isinstance(p, dict))
        print(f"Span captured text: '{span_text}'")
        # Should have more than just "1,"
        assert len(span_text) > 2, f"Span should capture more than just '1,', got: '{span_text}'"
        assert "1" in span_text, "Span should contain '1'"
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_direct_api_streaming_no_duplication_comprehensive(memory_logger, direct):
    """Comprehensive test matching golden test setup to verify no duplication and full output capture.

    Collects streamed text, asserts the counted sequence appears exactly once,
    then inspects the ``model_request_stream`` span's output parts when present.
    """
    assert not memory_logger.pop()

    # Match golden test exactly
    IDENTICAL_PROMPT = "Count from 1 to 5."
    IDENTICAL_SETTINGS = ModelSettings(max_tokens=100)

    messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]

    collected_text = ""
    chunk_types = []
    seen_delta = False

    async with direct.model_request_stream(messages=messages, model_settings=IDENTICAL_SETTINGS, model="openai:gpt-4o") as stream:
        async for chunk in stream:
            # Part content is only taken before the first delta arrives. Any
            # part-carrying event seen afterwards is ignored, presumably so that
            # already-streamed text is not counted twice.
            if hasattr(chunk, "part") and hasattr(chunk.part, "content") and not seen_delta:
                text = str(chunk.part.content)
                chunk_types.append(("PartStartEvent", text))
                collected_text += text
            elif hasattr(chunk, "delta") and chunk.delta:
                seen_delta = True
                if hasattr(chunk.delta, "content_delta") and chunk.delta.content_delta:
                    text = chunk.delta.content_delta
                    chunk_types.append(("PartDeltaEvent", text))
                    collected_text += text

    print(f"\nCollected text: '{collected_text}'")
    print(f"Total chunks received: {len(chunk_types)}")
    print("All chunk types:")
    for i, (chunk_type, content) in enumerate(chunk_types):
        print(f" {i}: {chunk_type} = {content!r}")

    # Verify no duplication in collected text
    # Expected: "Sure! Here you go:\n\n1, 2, 3, 4, 5." or similar (length ~30)
    # Should NOT be duplicated
    assert len(collected_text) < 60, f"Text seems duplicated (too long): '{collected_text}' (len={len(collected_text)})"
    assert collected_text.count("1, 2, 3") == 1, f"Text should appear once, not duplicated: '{collected_text}'"

    # Check span
    spans = memory_logger.pop()
    print(f"Number of spans: {len(spans)}")
    for i, s in enumerate(spans):
        print(f"Span {i}: {s['span_attributes']['name']} (type: {s['span_attributes'].get('type', 'N/A')})")
        if "span_parents" in s and s["span_parents"]:
            print(f" Parents: {s['span_parents']}")

    # Should have 1 or 2 spans (direct API wrapper + potentially model wrapper)
    assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"

    # Find the model_request_stream span
    stream_span = next((s for s in spans if "model_request_stream" in s["span_attributes"]["name"]), None)
    assert stream_span is not None, "model_request_stream span not found"

    # Check that span output is not empty and captures reasonable amount of text.
    # The output may be present but None; normalize to {} so the membership
    # test below does not raise TypeError.
    span_output = stream_span.get("output") or {}
    print(f"Span output keys: {span_output.keys() if span_output else 'None'}")

    if "parts" in span_output:
        parts = span_output.get("parts", [])
        print(f"Span parts: {parts}")
        if parts:
            first_part = parts[0]
            print(f"First part type: {type(first_part)}")
            print(f"First part: {first_part}")
            if isinstance(first_part, dict):
                # A None content would otherwise break len() below.
                part_content = first_part.get("content") or ""
                print(f"Part content: '{part_content}'")
                print(f"Part content length: {len(part_content)}")
                # The span should capture the FULL text, not just "1,"
                assert len(part_content) > 5, f"Span should capture full text, got: '{part_content}'"
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_async_generator_pattern_call_6(memory_logger):
    """Test async generator pattern (call 6) - wrapping stream in async generator.

    Consumes at most three events from an async generator that re-yields
    ``agent.run_stream_events`` and requires some text to have been extracted.
    """
    assert not memory_logger.pop()

    IDENTICAL_PROMPT = "Count from 1 to 5."
    max_events = 3

    async def stream_with_async_generator(prompt: str):
        """Wrap the stream in an async generator (customer pattern)."""
        agent = Agent("openai:gpt-4o", model_settings=ModelSettings(max_tokens=100))
        async for event in agent.run_stream_events(prompt):
            yield event

    collected_text = ""
    events_seen = 0
    # NOTE(review): `break` below does not close the wrapping async generator;
    # Python finalizes it later (garbage collection / event-loop shutdown),
    # which is the behavior callers using this pattern get.
    async for event in stream_with_async_generator(IDENTICAL_PROMPT):
        # run_stream_events returns ResultEvent objects with different structure
        # Try to extract text from whatever event type we get
        if hasattr(event, "content") and event.content:
            collected_text += str(event.content)
        elif hasattr(event, "part") and hasattr(event.part, "content"):
            collected_text += str(event.part.content)
        elif hasattr(event, "delta") and event.delta:
            if hasattr(event.delta, "content_delta") and event.delta.content_delta:
                collected_text += event.delta.content_delta

        events_seen += 1
        if events_seen >= max_events:
            break

    # This should capture something
    print(f"Collected text from generator: '{collected_text}'")
    assert len(collected_text) > 0, "Expected some text from async generator but got empty string"
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_structured_output(memory_logger):
    """Agent.run with a Pydantic ``output_type`` logs an agent_run span with a nested chat span.

    Checks the parsed structured result, then the span types, metadata, input and
    output, metrics, and the parent/child relationship between the two spans.
    """
    assert not memory_logger.pop()

    class MathAnswer(BaseModel):
        answer: int
        explanation: str

    agent = Agent(MODEL, output_type=MathAnswer, model_settings=ModelSettings(max_tokens=200))

    start = time.time()
    result = await agent.run("What is 10 + 15?")
    end = time.time()

    # The structured output must be parsed into the model class.
    output = result.output
    assert isinstance(output, MathAnswer)
    assert output.answer == 25
    assert output.explanation

    spans = memory_logger.pop()
    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run + chat), got {len(spans)}"

    def span_name(span):
        return span["span_attributes"]["name"]

    agent_span = next(
        (s for s in spans if "agent_run" in span_name(s) and "chat" not in span_name(s)),
        None,
    )
    chat_span = next((s for s in spans if "chat" in span_name(s)), None)
    assert agent_span is not None, "agent_run span not found"
    assert chat_span is not None, "chat span not found"

    # agent_run span: type, metadata, input/output and metrics.
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    agent_meta = agent_span["metadata"]
    assert agent_meta["model"] == "gpt-4o-mini"
    assert agent_meta["provider"] == "openai"
    assert "10 + 15" in str(agent_span["input"])
    assert "25" in str(agent_span["output"])
    _assert_metrics_are_valid(agent_span["metrics"], start, end)

    spans_by_id = {s["span_id"]: s for s in spans}

    def descends_from(span, ancestor_id):
        """Return True if ``ancestor_id`` is a direct or transitive parent of ``span``."""
        parent_ids = span.get("span_parents") or []
        if ancestor_id in parent_ids:
            return True
        return any(
            pid in spans_by_id and descends_from(spans_by_id[pid], ancestor_id)
            for pid in parent_ids
        )

    assert descends_from(chat_span, agent_span["span_id"]), "chat span should be nested under agent_run"
    chat_meta = chat_span["metadata"]
    assert chat_meta["model"] == "gpt-4o-mini"
    assert chat_meta["provider"] == "openai"
    _assert_metrics_are_valid(chat_span["metrics"], start, end)

    # The agent span must report token usage.
    agent_metrics = agent_span["metrics"]
    assert "prompt_tokens" in agent_metrics
    assert "completion_tokens" in agent_metrics
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_model_settings_in_metadata(memory_logger):
    """Test that model_settings from agent config appears in metadata, not input.

    Agent-level settings (given to ``Agent(...)``, not to ``run()``) are expected
    under the agent_run span's ``metadata``. The nested chat span receives them
    in its ``input``.
    """
    assert not memory_logger.pop()

    custom_settings = ModelSettings(max_tokens=100, temperature=0.5)
    agent = Agent(MODEL, model_settings=custom_settings)

    start = time.time()
    result = await agent.run("Say hello")
    end = time.time()

    assert result.output

    # Check spans
    spans = memory_logger.pop()
    assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"

    # Find agent_run and chat spans
    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
    chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)

    assert agent_span is not None, "agent_run span not found"
    assert chat_span is not None, "chat span not found"

    # Verify model_settings is in agent METADATA (not input, since it's agent config)
    assert "model_settings" in agent_span["metadata"], "model_settings should be in agent_run metadata"
    agent_settings = agent_span["metadata"]["model_settings"]
    assert agent_settings["max_tokens"] == 100
    assert agent_settings["temperature"] == 0.5

    # Verify model_settings is NOT in agent input (it wasn't passed to run())
    assert "model_settings" not in agent_span["input"], "model_settings should NOT be in agent_run input when not passed to run()"

    # Verify model_settings is in chat input (passed to the model)
    assert "model_settings" in chat_span["input"], "model_settings should be in chat span input"
    chat_settings = chat_span["input"]["model_settings"]
    assert chat_settings["max_tokens"] == 100
    assert chat_settings["temperature"] == 0.5

    # Verify model_settings is NOT in chat metadata (it's in input)
    assert "model_settings" not in chat_span["metadata"], "model_settings should NOT be in chat span metadata"

    # Verify other metadata is present
    assert chat_span["metadata"]["model"] == "gpt-4o-mini"
    assert chat_span["metadata"]["provider"] == "openai"

    # Validate the agent span's metrics against the measured window, as the
    # other agent.run() tests in this file do.
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_model_settings_override_in_input(memory_logger):
    """Test that model_settings passed to run() appears in input, not metadata.

    The agent has default settings. The per-call override must be logged in the
    agent_run span's ``input``, and the defaults must not be duplicated into its
    ``metadata``.
    """
    assert not memory_logger.pop()

    # Agent has default settings
    default_settings = ModelSettings(max_tokens=50)
    agent = Agent(MODEL, model_settings=default_settings)

    # Override with different settings in run() call
    override_settings = ModelSettings(max_tokens=200, temperature=0.9)

    start = time.time()
    result = await agent.run("Tell me a story", model_settings=override_settings)
    end = time.time()

    assert result.output

    # Check spans
    spans = memory_logger.pop()
    assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"

    # Find agent_run span
    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run span not found"

    # Verify override settings are in agent INPUT (because they were passed to run())
    assert "model_settings" in agent_span["input"], "model_settings should be in agent_run input when passed to run()"
    input_settings = agent_span["input"]["model_settings"]
    assert input_settings["max_tokens"] == 200, "Should use override settings from run() call"
    assert input_settings["temperature"] == 0.9

    # Verify agent default settings are NOT in metadata (when overridden in input, we don't duplicate in metadata)
    assert "model_settings" not in agent_span["metadata"], "model_settings should NOT be in metadata when explicitly passed to run()"

    # Validate the agent span's metrics against the measured window, as the
    # other agent.run() tests in this file do.
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
1020
|
+
|
|
1021
|
+
|
|
1022
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_system_prompt_in_metadata(memory_logger):
    """Test that system_prompt from agent config appears in input (it's semantically part of LLM input).

    Despite the function name, the assertions require the system prompt to be in
    the agent_run span's ``input`` (as the literal string) and NOT in its
    ``metadata``.
    """
    assert not memory_logger.pop()

    system_prompt = "You are a helpful AI assistant who speaks like a pirate."
    agent = Agent(MODEL, system_prompt=system_prompt, model_settings=ModelSettings(max_tokens=100))

    start = time.time()
    result = await agent.run("What is the weather?")
    end = time.time()

    assert result.output

    # Check spans
    spans = memory_logger.pop()
    assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"

    # Find agent_run span
    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run span not found"

    # Verify system_prompt is in input (because it's semantically part of the LLM input)
    assert "system_prompt" in agent_span["input"], "system_prompt should be in agent_run input"
    assert agent_span["input"]["system_prompt"] == system_prompt, "system_prompt should be the actual string, not a method reference"

    # Verify system_prompt is NOT in metadata
    assert "system_prompt" not in agent_span["metadata"], "system_prompt should NOT be in agent_run metadata"

    # Verify other metadata is present
    assert agent_span["metadata"]["model"] == "gpt-4o-mini"
    assert agent_span["metadata"]["provider"] == "openai"

    # Validate the agent span's metrics against the measured window, as the
    # other agent.run() tests in this file do.
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
1055
|
+
|
|
1056
|
+
|
|
1057
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_message_history(memory_logger):
    """A second run given the first run's messages recalls earlier context and is traced.

    Spans from the seeding run are discarded. Only the follow-up run's agent_run
    span is inspected.
    """
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))

    # Seed the conversation, then drop the spans it produced.
    first_run = await agent.run("My name is Alice")
    assert first_run.output
    memory_logger.pop()

    start = time.time()
    second_run = await agent.run("What is my name?", message_history=first_run.all_messages())
    end = time.time()

    # The model must recall the name from the supplied history.
    assert "Alice" in str(second_run.output)

    spans = memory_logger.pop()
    assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"

    def is_agent_run(span):
        name = span["span_attributes"]["name"]
        return "agent_run" in name and "chat" not in name

    agent_span = next(filter(is_agent_run, spans), None)
    assert agent_span is not None, "agent_run span not found"

    assert "message_history" in str(agent_span["input"])
    assert "Alice" in str(agent_span["output"])
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
1092
|
+
|
|
1093
|
+
|
|
1094
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_custom_settings(memory_logger):
    """Per-call model settings given to Agent.run() are logged in the agent span's input."""
    assert not memory_logger.pop()

    agent = Agent(MODEL)
    per_call_settings = ModelSettings(max_tokens=20, temperature=0.5, top_p=0.9)

    start = time.time()
    result = await agent.run("Say hello", model_settings=per_call_settings)
    end = time.time()

    assert result.output

    spans = memory_logger.pop()
    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run + chat), got {len(spans)}"

    def is_agent_run(span):
        name = span["span_attributes"]["name"]
        return "agent_run" in name and "chat" not in name

    agent_span = next(filter(is_agent_run, spans), None)
    assert agent_span is not None, "agent_run span not found"

    # Settings passed to run() should be recorded in the span input.
    assert "model_settings" in agent_span["input"]
    logged_settings = agent_span["input"]["model_settings"]
    expected_settings = {"max_tokens": 20, "temperature": 0.5, "top_p": 0.9}
    for key, value in expected_settings.items():
        assert logged_settings[key] == value
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
1130
|
+
|
|
1131
|
+
|
|
1132
|
+
@pytest.mark.vcr
def test_agent_run_stream_sync(memory_logger):
    """Agent.run_stream_sync() streams text and logs a parent span with a nested chat span."""
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))

    start = time.time()
    streamed = agent.run_stream_sync("Count from 1 to 3")
    full_text = "".join(streamed.stream_text(delta=True))
    end = time.time()

    # Some streamed content must mention at least one of the counted numbers.
    assert full_text
    assert any(digit in full_text for digit in ("1", "2", "3"))

    spans = memory_logger.pop()
    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run_stream_sync + chat), got {len(spans)}"

    agent_span = next((s for s in spans if "agent_run_stream_sync" in s["span_attributes"]["name"]), None)
    chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run_stream_sync span not found"
    assert chat_span is not None, "chat span not found"

    # Parent span: type, model, input and metrics.
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert agent_span["metadata"]["model"] == "gpt-4o-mini"
    assert "Count from 1 to 3" in str(agent_span["input"])
    _assert_metrics_are_valid(agent_span["metrics"], start, end)

    spans_by_id = {s["span_id"]: s for s in spans}

    def descends_from(span, ancestor_id):
        """Return True if ``ancestor_id`` is a direct or transitive parent of ``span``."""
        parent_ids = span.get("span_parents") or []
        if ancestor_id in parent_ids:
            return True
        return any(
            pid in spans_by_id and descends_from(spans_by_id[pid], ancestor_id)
            for pid in parent_ids
        )

    assert descends_from(chat_span, agent_span["span_id"]), "chat span should be nested under agent_run_stream_sync"
    assert chat_span["metadata"]["model"] == "gpt-4o-mini"
    assert chat_span["metadata"]["provider"] == "openai"
    # The intermediate chat span is only required to carry a start time.
    assert "start" in chat_span["metrics"]

    # The parent span must report token usage.
    assert "prompt_tokens" in agent_span["metrics"]
    assert "completion_tokens" in agent_span["metrics"]
|
|
1190
|
+
|
|
1191
|
+
|
|
1192
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_run_stream_events(memory_logger):
    """Agent.run_stream_events() yields events and logs a span whose event_count matches."""
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))

    start = time.time()
    # Drain the event stream completely.
    events = [event async for event in agent.run_stream_events("What is 5+5?")]
    end = time.time()
    event_count = len(events)

    assert event_count > 0, "Should receive at least one event"

    spans = memory_logger.pop()
    assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"

    agent_span = next((s for s in spans if "agent_run_stream_events" in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run_stream_events span not found"

    # Basic span structure plus the recorded event count.
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert agent_span["metadata"]["model"] == "gpt-4o-mini"
    logged_input = str(agent_span["input"])
    assert "5+5" in logged_input or "What" in logged_input
    assert agent_span["metrics"]["event_count"] == event_count
    _assert_metrics_are_valid(agent_span["metrics"], start, end)
|
|
1226
|
+
|
|
1227
|
+
|
|
1228
|
+
@pytest.mark.vcr
def test_direct_model_request_stream_sync(memory_logger, direct):
    """direct.model_request_stream_sync() yields chunks and logs exactly one LLM span."""
    assert not memory_logger.pop()

    prompt_messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 3")])]

    start = time.time()
    with direct.model_request_stream_sync(model=MODEL, messages=prompt_messages) as stream:
        chunk_count = sum(1 for _ in stream)
    end = time.time()

    assert chunk_count > 0

    spans = memory_logger.pop()
    assert len(spans) == 1

    (span,) = spans
    attributes = span["span_attributes"]
    assert attributes["type"] == SpanTypeAttribute.LLM
    assert attributes["name"] == "model_request_stream_sync"
    assert span["metadata"]["model"] == "gpt-4o-mini"
    _assert_metrics_are_valid(span["metrics"], start, end)
|
|
1254
|
+
|
|
1255
|
+
|
|
1256
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_stream_early_break_async_generator(memory_logger, direct):
    """Test breaking early from an async generator wrapper around a stream.

    This reproduces the 'Token was created in a different Context' error that occurs
    when breaking early from async generators. The cleanup happens in a different
    async context, causing ContextVar token errors.

    Our fix: Clear the context token before cleanup to make unset_current() use
    the safe set(None) path instead of reset(token).
    """
    assert not memory_logger.pop()

    messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 5")])]
    # Number of chunks the wrapper yields before breaking; also the expected count.
    max_chunks = 3

    async def stream_wrapper():
        """Wrap the stream in an async generator (common customer pattern)."""
        async with direct.model_request_stream(model=MODEL, messages=messages) as stream:
            count = 0
            async for chunk in stream:
                yield chunk
                count += 1
                if count >= max_chunks:
                    # Break early - this triggers cleanup in different context
                    break

    chunk_count = 0

    # This should NOT raise ValueError about "different Context"
    async for _chunk in stream_wrapper():
        chunk_count += 1

    # Should not raise ValueError about context token
    assert chunk_count == max_chunks

    # Check spans - should have created a span despite early break
    spans = memory_logger.pop()
    assert len(spans) >= 1, "Should have at least one span even with early break"

    span = spans[0]
    assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert span["span_attributes"]["name"] == "model_request_stream"
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_stream_early_break(memory_logger):
    """Test breaking early from agent.run_stream() context manager.

    Verifies that breaking early from the stream doesn't cause context token errors
    and that at least one span is still logged.
    """
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))

    text_count = 0

    # Break early from stream - should not raise context token error
    async with agent.run_stream("Count from 1 to 10") as result:
        async for _text in result.stream_text(delta=True):
            text_count += 1
            if text_count >= 2:  # Lower threshold - streaming may not produce many chunks
                break  # Early break

    assert text_count >= 1  # At least one chunk received

    # Check spans - may have incomplete spans due to early break
    spans = memory_logger.pop()
    assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"

    # Find agent_run_stream span (if created)
    agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)

    # The agent_run_stream span is optional here: its structure is checked only
    # when it was logged, and its metrics may be incomplete due to the early break.
    if agent_span:
        assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
        assert "start" in agent_span["metrics"]
|
|
1342
|
+
|
|
1343
|
+
|
|
1344
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_binary_content(memory_logger):
    """Test that agents with binary content (images) work correctly.

    Full binary-content serialization with attachment references is not covered
    here. This test only verifies that passing binary content to an agent does
    not break tracing.
    """
    # BinaryContent belongs to the public message API; import it from there rather
    # than from pydantic_ai.models.function, which only re-exports it.
    from pydantic_ai.messages import BinaryContent

    assert not memory_logger.pop()

    # A tiny 1x1 PNG (IHDR width/height are both 1)
    image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82'

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))

    start = time.time()
    result = await agent.run(
        [
            BinaryContent(data=image_data, media_type="image/png"),
            "What color is this image?",
        ]
    )
    end = time.time()

    assert result.output
    assert isinstance(result.output, str)

    # Check spans - verify basic tracing works
    spans = memory_logger.pop()
    assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"

    # Find the agent_run span (excluding the nested chat span)
    agent_span = next(
        (
            s
            for s in spans
            if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]
        ),
        None,
    )
    assert agent_span is not None, "agent_run span not found"

    # Verify basic span structure
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert agent_span["metadata"]["model"] == "gpt-4o-mini"
    _assert_metrics_are_valid(agent_span["metrics"], start, end)

    # TODO: Future enhancement - add full binary content serialization with attachment references
|
|
1387
|
+
|
|
1388
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_with_tool_execution(memory_logger):
    """Test that an agent run that calls a tool is traced and records its tools.

    Verifies that:
    1. The agent answers correctly using the ``calculate`` tool.
    2. An agent_run span is logged (together with its chat spans).
    3. The agent's tool schemas are captured in the span *input* under
       ``toolsets``, and not in metadata.

    Tool-execution spans ("running tools" parent / "running tool: calculate"
    child) are not created by the wrapper yet; see the TODO below.
    """
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=200))

    # NOTE: the tool signature and docstring are sent to the model as the tool
    # schema; keep them stable so recorded cassettes keep matching.
    @agent.tool_plain
    def calculate(operation: str, a: float, b: float) -> str:
        """Perform a mathematical calculation.

        Args:
            operation: The mathematical operation (add, subtract, multiply, divide)
            a: First number
            b: Second number
        """
        ops = {
            "add": a + b,
            "subtract": a - b,
            "multiply": a * b,
            "divide": a / b if b != 0 else "Error: Division by zero",
        }
        return str(ops.get(operation, "Invalid operation"))

    result = await agent.run("What is 127 multiplied by 49?")

    assert result.output
    # 127 * 49 = 6223; tolerate a thousands separator ("6,223") in the model's answer.
    assert "6223" in str(result.output).replace(",", ""), f"Expected calculation result in output: {result.output}"

    # Check spans
    spans = memory_logger.pop()

    # We should have at least agent_run and chat spans
    # TODO: Add "running tools" parent span and "running tool: calculate" child span
    assert len(spans) >= 2, f"Expected at least 2 spans, got {len(spans)}"

    # Find agent_run span
    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"]), None)
    assert agent_span is not None, "agent_run span not found"

    # Verify that toolsets are captured in input with correct tool names
    assert "toolsets" in agent_span["input"], "toolsets should be in input (not metadata)"
    toolsets = agent_span["input"]["toolsets"]
    assert len(toolsets) > 0, "At least one toolset should be present"

    # Find the agent toolset
    agent_toolset = next((ts for ts in toolsets if ts.get("id") == "<agent>"), None)
    assert agent_toolset is not None, "Agent toolset not found"
    assert "tools" in agent_toolset, "tools should be in agent toolset"

    # Verify calculate tool is present (tools are dicts with full schemas in input)
    tools = agent_toolset["tools"]
    assert isinstance(tools, list), "tools should be a list"
    tool_names = [t["name"] for t in tools if isinstance(t, dict)]
    assert "calculate" in tool_names, f"calculate tool should be in tools list, got: {tool_names}"

    # Verify toolsets are NOT in metadata (following the principle: agent.run() accepts it)
    assert "toolsets" not in agent_span["metadata"], "toolsets should NOT be in metadata"
|
|
1460
|
+
|
|
1461
|
+
|
|
1462
|
+
def test_tool_execution_creates_spans(memory_logger):
    """Test that executing tools with agents works and produces traced spans.

    Tool-level span creation is not yet implemented in the wrapper. This test
    therefore only verifies that an agent with tools runs correctly, that
    agent_run and chat spans are logged, and that the tool schema is captured
    in the agent span's input.
    """
    # NOTE(review): unlike the neighbouring model-calling tests, this test is not
    # marked @pytest.mark.vcr, so run_sync() below needs live network access and an
    # API key. Confirm whether a cassette should be recorded for it.
    assert not memory_logger.pop()

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=500))

    @agent.tool_plain
    def calculate(operation: str, a: float, b: float) -> float:
        """Perform a mathematical calculation."""
        if operation == "multiply":
            return a * b
        elif operation == "add":
            return a + b
        else:
            return 0.0

    # Run the agent with a query that will use the tool
    result = agent.run_sync("What is 127 multiplied by 49?")

    # Verify the tool was actually called and result is correct
    assert result.output
    assert "6223" in str(result.output) or "6,223" in str(result.output), f"Expected calculation result in output: {result.output}"

    # Get logged spans
    spans = memory_logger.pop()

    # Find spans by type
    agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"]), None)
    chat_spans = [s for s in spans if "chat" in s["span_attributes"]["name"]]

    # Assertions - verify basic tracing works with tools
    assert agent_span is not None, "agent_run span should exist"
    assert len(chat_spans) >= 1, f"Expected at least 1 chat span, got {len(chat_spans)}"

    # Verify agent span has tool information in input
    assert "toolsets" in agent_span["input"], "Tool information should be captured in agent input"
    toolsets = agent_span["input"]["toolsets"]
    agent_toolset = next((ts for ts in toolsets if ts.get("id") == "<agent>"), None)
    assert agent_toolset is not None, "Agent toolset should be in input"

    # Verify calculate tool is in the toolset
    tools = agent_toolset.get("tools", [])
    tool_names = [t["name"] for t in tools if isinstance(t, dict)]
    assert "calculate" in tool_names, f"calculate tool should be in toolset, got: {tool_names}"

    # TODO: once the wrapper creates tool-execution spans, assert that at least one
    # span whose name contains "calculate" is logged as a child of the chat span
    # that requested the tool call.
|
|
1520
|
+
|
|
1521
|
+
|
|
1522
|
+
def test_agent_tool_metadata_extraction():
    """Test that agent tools are extracted with full schemas into INPUT (not metadata).

    Principle: if agent.run() accepts it, it goes in input only. This is a pure
    extraction test, so no logger fixture or model call is needed.
    """
    from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata

    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))

    # Add multiple tools with different signatures
    @agent.tool_plain
    def calculate(operation: str, a: float, b: float) -> str:
        """Perform a mathematical calculation."""
        return str(a + b)

    @agent.tool_plain
    def get_weather(location: str) -> str:
        """Get weather for a location."""
        return f"Weather in {location}"

    @agent.tool_plain
    def search_database(query: str, limit: int = 10) -> str:
        """Search the database."""
        return "Results"

    # Extract input/metadata using the actual function signature
    args = ("Test prompt",)
    kwargs = {}
    input_data, metadata = _build_agent_input_and_metadata(args, kwargs, agent)

    # Verify toolsets are in INPUT (since agent.run() accepts toolsets parameter)
    assert "toolsets" in input_data, "toolsets should be in input (can be passed to agent.run())"
    toolsets = input_data["toolsets"]
    assert len(toolsets) > 0, "At least one toolset should be present"

    # Verify toolsets are NOT in metadata (following the principle)
    assert "toolsets" not in metadata, "toolsets should NOT be in metadata (it's a run() parameter)"

    # Find the agent toolset
    agent_toolset = next((ts for ts in toolsets if ts.get("id") == "<agent>"), None)
    assert agent_toolset is not None, "Agent toolset not found in input"
    assert agent_toolset.get("label") == "the agent", "Agent toolset should have correct label"
    assert "tools" in agent_toolset, "tools should be in agent toolset"

    # Verify all tools are present with FULL SCHEMAS
    tools = agent_toolset["tools"]
    assert isinstance(tools, list), "tools should be a list"
    assert len(tools) == 3, f"Should have exactly 3 tools, got {len(tools)}"

    # Check each tool has full schema information
    tool_names = [t["name"] for t in tools]
    assert "calculate" in tool_names, f"calculate tool should be present, got: {tool_names}"
    assert "get_weather" in tool_names, f"get_weather tool should be present, got: {tool_names}"
    assert "search_database" in tool_names, f"search_database tool should be present, got: {tool_names}"

    # Verify calculate tool has full schema
    calculate_tool = next(t for t in tools if t["name"] == "calculate")
    assert "description" in calculate_tool, "Tool should have description"
    assert "Perform a mathematical calculation" in calculate_tool["description"]
    assert "parameters" in calculate_tool, "Tool should have parameters schema"
    params = calculate_tool["parameters"]
    assert "properties" in params, "Parameters should have properties"
    assert "operation" in params["properties"], "Should have 'operation' parameter"
    assert "a" in params["properties"], "Should have 'a' parameter"
    assert "b" in params["properties"], "Should have 'b' parameter"
    assert params["properties"]["operation"]["type"] == "string"
    assert params["properties"]["a"]["type"] == "number"
    assert params["properties"]["b"]["type"] == "number"

    # Verify search_database has an optional parameter
    search_tool = next(t for t in tools if t["name"] == "search_database")
    assert "parameters" in search_tool
    search_params = search_tool["parameters"]
    assert "query" in search_params["properties"]
    assert "limit" in search_params["properties"]
    # 'query' must be required; 'limit' has a default, so it must not be required
    required = search_params.get("required", [])
    assert "query" in required
    assert "limit" not in required, f"'limit' has a default and should be optional, got required={required}"
|
|
1604
|
+
|
|
1605
|
+
|
|
1606
|
+
def test_agent_without_tools_metadata():
    """Test input/metadata extraction for an agent with no registered tools."""
    from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata

    # Agent with no tools
    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))

    args = ("Test prompt",)
    kwargs = {}
    input_data, metadata = _build_agent_input_and_metadata(args, kwargs, agent)

    # toolsets may be absent/None or a list, depending on which toolsets the
    # agent exposes; either way it must have a well-defined shape.
    toolsets = input_data.get("toolsets")
    assert isinstance(toolsets, (list, type(None))), "toolsets should be list or None in input"

    # Like other run() parameters, toolsets belong in input, never in metadata.
    assert "toolsets" not in metadata, "toolsets should NOT be in metadata (it's a run() parameter)"

    # If the agent's own function toolset is reported, it must not list any tools.
    agent_toolset = next((ts for ts in toolsets or [] if ts.get("id") == "<agent>"), None)
    if agent_toolset is not None:
        assert not agent_toolset.get("tools"), f"Agent without tools should report no tools, got: {agent_toolset.get('tools')}"
|
|
1621
|
+
|
|
1622
|
+
|
|
1623
|
+
def test_agent_tool_with_custom_name():
    """A tool registered under a custom name is reported by that name, with its schema, in input."""
    from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata

    agent = Agent(MODEL)

    # The public tool name deliberately differs from the Python function name.
    @agent.tool_plain(name="custom_calculator")
    def calc(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    input_data, metadata = _build_agent_input_and_metadata(("Test",), {}, agent)

    # Tool information is run() input, so it must not leak into metadata.
    assert "toolsets" in input_data
    assert "toolsets" not in metadata, "toolsets should not be in metadata"

    agent_toolset = None
    for candidate in input_data["toolsets"]:
        if candidate.get("id") == "<agent>":
            agent_toolset = candidate
            break
    assert agent_toolset is not None

    registered = agent_toolset.get("tools", [])

    # Exactly one tool, serialized as a dict carrying its schema.
    assert len(registered) == 1, f"Should have 1 tool, got {len(registered)}"
    only_tool = registered[0]
    assert isinstance(only_tool, dict), "Tool should be a dict with schema"
    assert only_tool["name"] == "custom_calculator", f"Should use custom name, got: {only_tool.get('name')}"
    assert "description" in only_tool, "Tool should have description"
    assert "parameters" in only_tool, "Tool should have parameters schema"
    properties = only_tool["parameters"]["properties"]
    assert "a" in properties
    assert "b" in properties
|
|
1656
|
+
|
|
1657
|
+
|
|
1658
|
+
def test_explicit_toolsets_kwarg_in_input():
    """A ``toolsets`` kwarg passed explicitly to run() must appear in the span input."""
    from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata

    agent = Agent(MODEL)

    # Give the agent a tool of its own so it has a non-trivial toolset as well.
    @agent.tool_plain
    def helper_tool() -> str:
        """A helper tool."""
        return "help"

    # Real callers would pass toolset objects; a plain marker string is enough to
    # check that the explicit kwarg is forwarded into input.
    call_args = ("Test",)
    call_kwargs = {"toolsets": "custom_toolset_marker"}
    input_data, _metadata = _build_agent_input_and_metadata(call_args, call_kwargs, agent)

    assert "toolsets" in input_data, "explicitly passed toolsets should be in input"
|
|
1678
|
+
|
|
1679
|
+
|
|
1680
|
+
def test_reasoning_tokens_extraction(memory_logger):
    """Test that reasoning tokens are extracted from model responses.

    For reasoning models like o1/o3, ``usage.details.reasoning_tokens`` should be
    captured in the metrics as ``completion_reasoning_tokens``. The response is
    built from MagicMock objects, so no HTTP request is made and no VCR cassette
    is needed.
    """
    from unittest.mock import MagicMock

    from braintrust.wrappers.pydantic_ai import _extract_response_metrics

    assert not memory_logger.pop()

    # Build a mock response that carries reasoning tokens
    mock_response = MagicMock()
    mock_response.parts = [
        MagicMock(
            part_kind="thinking",
            content="Let me think about this...",
        ),
        MagicMock(
            part_kind="text",
            content="The answer is 42",
        ),
    ]
    mock_response.usage = MagicMock()
    mock_response.usage.input_tokens = 10
    mock_response.usage.output_tokens = 20
    mock_response.usage.total_tokens = 30
    mock_response.usage.cache_read_tokens = 0
    mock_response.usage.cache_write_tokens = 0
    mock_response.usage.details = MagicMock()
    mock_response.usage.details.reasoning_tokens = 128

    # Exercise the metric extraction function directly
    start_time = time.time()
    end_time = start_time + 5.0

    metrics = _extract_response_metrics(mock_response, start_time, end_time)

    # Verify all metrics are present
    assert metrics is not None, "Should extract metrics"
    # pylint: disable=unsupported-membership-test,unsubscriptable-object
    assert "prompt_tokens" in metrics, "Should have prompt_tokens"
    assert metrics["prompt_tokens"] == 10.0
    assert "completion_tokens" in metrics, "Should have completion_tokens"
    assert metrics["completion_tokens"] == 20.0
    assert "tokens" in metrics, "Should have total tokens"
    assert metrics["tokens"] == 30.0
    assert "completion_reasoning_tokens" in metrics, "Should have completion_reasoning_tokens"
    assert metrics["completion_reasoning_tokens"] == 128.0, f"Expected 128.0, got {metrics['completion_reasoning_tokens']}"
    assert "duration" in metrics
    assert "start" in metrics
    assert "end" in metrics
    # pylint: enable=unsupported-membership-test,unsubscriptable-object
|
|
1736
|
+
|
|
1737
|
+
|
|
1738
|
+
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_agent_run_stream_structured_output(memory_logger):
    """Agent.run_stream() with a Pydantic output model.

    Reads the structured result via ``get_output()`` and checks that an
    agent_run_stream span is logged with a chat span nested directly beneath it.
    """
    assert not memory_logger.pop()

    class Product(BaseModel):
        name: str
        price: float

    agent = Agent(MODEL, output_type=Product, model_settings=ModelSettings(max_tokens=200))

    start = time.time()
    async with agent.run_stream("Create a product: wireless mouse for $29.99") as result:
        # Structured output is read in one go rather than streamed as text
        product = await result.get_output()
    end = time.time()

    # The parsed output is a populated Product instance
    assert isinstance(product, Product)
    assert product.name
    assert product.price > 0

    spans = memory_logger.pop()
    assert len(spans) >= 2, f"Expected at least 2 spans (agent_run_stream + chat), got {len(spans)}"

    def _first_span_named(fragment):
        """Return the first logged span whose name contains ``fragment``, else None."""
        return next((s for s in spans if fragment in s["span_attributes"]["name"]), None)

    agent_span = _first_span_named("agent_run_stream")
    chat_span = _first_span_named("chat")

    assert agent_span is not None, "agent_run_stream span not found"
    assert chat_span is not None, "chat span not found"

    # Top-level agent span
    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert agent_span["metadata"]["model"] == "gpt-4o-mini"
    _assert_metrics_are_valid(agent_span["metrics"], start, end)

    # The chat span must be a direct child of the agent span
    assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run_stream"
    assert chat_span["metadata"]["model"] == "gpt-4o-mini"
    _assert_metrics_are_valid(chat_span["metrics"], start, end)
|