braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (83)
  1. braintrust/__init__.py +4 -0
  2. braintrust/_generated_types.py +1200 -611
  3. braintrust/audit.py +2 -2
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/conftest.py +1 -0
  7. braintrust/context.py +12 -17
  8. braintrust/contrib/temporal/__init__.py +16 -27
  9. braintrust/contrib/temporal/test_temporal.py +8 -3
  10. braintrust/devserver/auth.py +8 -8
  11. braintrust/devserver/cache.py +3 -4
  12. braintrust/devserver/cors.py +8 -7
  13. braintrust/devserver/dataset.py +3 -5
  14. braintrust/devserver/eval_hooks.py +7 -6
  15. braintrust/devserver/schemas.py +22 -19
  16. braintrust/devserver/server.py +19 -12
  17. braintrust/devserver/test_cached_login.py +4 -4
  18. braintrust/framework.py +128 -140
  19. braintrust/framework2.py +88 -87
  20. braintrust/functions/invoke.py +93 -53
  21. braintrust/functions/stream.py +3 -2
  22. braintrust/generated_types.py +17 -1
  23. braintrust/git_fields.py +11 -11
  24. braintrust/gitutil.py +2 -3
  25. braintrust/graph_util.py +10 -10
  26. braintrust/id_gen.py +2 -2
  27. braintrust/logger.py +346 -357
  28. braintrust/merge_row_batch.py +10 -9
  29. braintrust/oai.py +107 -24
  30. braintrust/otel/__init__.py +49 -49
  31. braintrust/otel/context.py +16 -30
  32. braintrust/otel/test_distributed_tracing.py +14 -11
  33. braintrust/otel/test_otel_bt_integration.py +32 -31
  34. braintrust/parameters.py +8 -8
  35. braintrust/prompt.py +14 -14
  36. braintrust/prompt_cache/disk_cache.py +5 -4
  37. braintrust/prompt_cache/lru_cache.py +3 -2
  38. braintrust/prompt_cache/prompt_cache.py +13 -14
  39. braintrust/queue.py +4 -4
  40. braintrust/score.py +4 -4
  41. braintrust/serializable_data_class.py +4 -4
  42. braintrust/span_identifier_v1.py +1 -2
  43. braintrust/span_identifier_v2.py +3 -4
  44. braintrust/span_identifier_v3.py +23 -20
  45. braintrust/span_identifier_v4.py +34 -25
  46. braintrust/test_framework.py +16 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_otel.py +61 -53
  50. braintrust/test_queue.py +0 -1
  51. braintrust/test_score.py +1 -3
  52. braintrust/test_span_components.py +29 -44
  53. braintrust/util.py +9 -8
  54. braintrust/version.py +2 -2
  55. braintrust/wrappers/_anthropic_utils.py +4 -4
  56. braintrust/wrappers/agno/__init__.py +3 -4
  57. braintrust/wrappers/agno/agent.py +1 -2
  58. braintrust/wrappers/agno/function_call.py +1 -2
  59. braintrust/wrappers/agno/model.py +1 -2
  60. braintrust/wrappers/agno/team.py +1 -2
  61. braintrust/wrappers/agno/utils.py +12 -12
  62. braintrust/wrappers/anthropic.py +7 -8
  63. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  64. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  65. braintrust/wrappers/dspy.py +15 -17
  66. braintrust/wrappers/google_genai/__init__.py +16 -16
  67. braintrust/wrappers/langchain.py +22 -24
  68. braintrust/wrappers/litellm.py +4 -3
  69. braintrust/wrappers/openai.py +15 -15
  70. braintrust/wrappers/pydantic_ai.py +1204 -0
  71. braintrust/wrappers/test_agno.py +0 -1
  72. braintrust/wrappers/test_dspy.py +0 -1
  73. braintrust/wrappers/test_google_genai.py +2 -3
  74. braintrust/wrappers/test_litellm.py +0 -1
  75. braintrust/wrappers/test_oai_attachments.py +322 -0
  76. braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
  77. braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
  78. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
  79. braintrust-0.4.0.dist-info/RECORD +120 -0
  80. braintrust-0.3.14.dist-info/RECORD +0 -117
  81. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
  82. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
  83. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
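
The most significant additions in this release are the new Pydantic AI integration (braintrust/wrappers/pydantic_ai.py) and its integration tests (braintrust/wrappers/test_pydantic_ai_integration.py, reproduced below). Based only on how those tests exercise the wrapper, a minimal usage sketch might look like the following; the project name, model string, and prompt are taken from the test file, and any behavior of setup_pydantic_ai() beyond the project_name argument is not documented in this diff:

    # Sketch inferred from braintrust/wrappers/test_pydantic_ai_integration.py below.
    from braintrust import setup_pydantic_ai
    from pydantic_ai import Agent, ModelSettings

    # Patch pydantic-ai so Agent runs and direct model requests emit Braintrust
    # spans (the tests assert an "agent_run" span with a nested "chat" span,
    # plus model/provider metadata and token metrics).
    setup_pydantic_ai(project_name="test-pydantic-ai-integration")

    agent = Agent("openai:gpt-4o-mini", model_settings=ModelSettings(max_tokens=50))
    result = agent.run_sync("What is 2+2? Answer with just the number.")
    print(result.output)

The same setup also traces the pydantic_ai.direct helpers (model_request, model_request_sync, and their streaming variants), for which the tests check model_request* spans.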
--- /dev/null
+++ b/braintrust/wrappers/test_pydantic_ai_integration.py
@@ -0,0 +1,1788 @@
1
+ # pyright: reportUntypedFunctionDecorator=false
2
+ # pyright: reportUnknownMemberType=false
3
+ # pyright: reportUnknownParameterType=false
4
+ # pyright: reportPrivateUsage=false
5
+ import time
6
+
7
+ import pytest
8
+ from braintrust import logger, setup_pydantic_ai, traced
9
+ from braintrust.span_types import SpanTypeAttribute
10
+ from braintrust.test_helpers import init_test_logger
11
+ from pydantic import BaseModel
12
+ from pydantic_ai import Agent, ModelSettings
13
+ from pydantic_ai.messages import ModelRequest, UserPromptPart
14
+
15
+ PROJECT_NAME = "test-pydantic-ai-integration"
16
+ MODEL = "openai:gpt-4o-mini" # Use cheaper model for tests
17
+ TEST_PROMPT = "What is 2+2? Answer with just the number."
18
+
19
+
20
+ @pytest.fixture(scope="module", autouse=True)
21
+ def setup_wrapper():
22
+ """Setup pydantic_ai wrapper before any tests run."""
23
+ setup_pydantic_ai(project_name=PROJECT_NAME)
24
+ yield
25
+
26
+
27
+ @pytest.fixture(scope="module")
28
+ def direct():
29
+ """Provide pydantic_ai.direct module after setup_wrapper has run."""
30
+ import pydantic_ai.direct as direct_module
31
+ return direct_module
32
+
33
+
34
+ @pytest.fixture(scope="module")
35
+ def vcr_config():
36
+ return {
37
+ "filter_headers": [
38
+ "authorization",
39
+ "openai-organization",
40
+ "x-api-key",
41
+ ]
42
+ }
43
+
44
+
45
+ @pytest.fixture
46
+ def memory_logger():
47
+ init_test_logger(PROJECT_NAME)
48
+ with logger._internal_with_memory_background_logger() as bgl:
49
+ yield bgl
50
+
51
+
52
+ def _assert_metrics_are_valid(metrics, start, end):
53
+ """Assert that metrics contain expected fields and values."""
54
+ assert "start" in metrics
55
+ assert "end" in metrics
56
+ assert "duration" in metrics
57
+ assert start <= metrics["start"] <= metrics["end"] <= end
58
+ assert metrics["duration"] > 0
59
+
60
+ # Token metrics (if present)
61
+ if "tokens" in metrics:
62
+ assert metrics["tokens"] > 0
63
+ if "prompt_tokens" in metrics:
64
+ assert metrics["prompt_tokens"] > 0
65
+ if "completion_tokens" in metrics:
66
+ assert metrics["completion_tokens"] > 0
67
+
68
+
69
+ @pytest.mark.vcr
70
+ @pytest.mark.asyncio
71
+ async def test_agent_run_async(memory_logger):
72
+ """Test Agent.run() async method."""
73
+ assert not memory_logger.pop()
74
+
75
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
76
+
77
+ start = time.time()
78
+ result = await agent.run(TEST_PROMPT)
79
+ end = time.time()
80
+
81
+ # Verify the result
82
+ assert result.output
83
+ assert "4" in str(result.output)
84
+
85
+ # Check spans - should now have parent agent_run + nested chat span
86
+ spans = memory_logger.pop()
87
+ assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"
88
+
89
+ # Find agent_run and chat spans
90
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"]), None)
91
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
92
+
93
+ assert agent_span is not None, "agent_run span not found"
94
+ assert chat_span is not None, "chat span not found"
95
+
96
+ # Check agent span
97
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
98
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
99
+ assert agent_span["metadata"]["provider"] == "openai"
100
+ assert TEST_PROMPT in str(agent_span["input"])
101
+ assert "4" in str(agent_span["output"])
102
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
103
+
104
+ # Check chat span is nested under agent span (use span_id, not id which is the row ID)
105
+ assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run"
106
+ assert chat_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
107
+ assert "gpt-4o-mini" in chat_span["span_attributes"]["name"]
108
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
109
+ assert chat_span["metadata"]["provider"] == "openai"
110
+ _assert_metrics_are_valid(chat_span["metrics"], start, end)
111
+
112
+ # Agent spans should have token metrics
113
+ assert "prompt_tokens" in agent_span["metrics"]
114
+ assert "completion_tokens" in agent_span["metrics"]
115
+ assert agent_span["metrics"]["prompt_tokens"] > 0
116
+ assert agent_span["metrics"]["completion_tokens"] > 0
117
+
118
+
119
+ @pytest.mark.vcr
120
+ def test_agent_run_sync(memory_logger):
121
+ """Test Agent.run_sync() synchronous method."""
122
+ assert not memory_logger.pop()
123
+
124
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
125
+
126
+ start = time.time()
127
+ result = agent.run_sync(TEST_PROMPT)
128
+ end = time.time()
129
+
130
+ # Verify the result
131
+ assert result.output
132
+ assert "4" in str(result.output)
133
+
134
+ # Check spans - should have parent agent_run_sync + nested spans
135
+ spans = memory_logger.pop()
136
+ assert len(spans) >= 2, f"Expected at least 2 spans (agent_run_sync + chat), got {len(spans)}"
137
+
138
+ # Find agent_run_sync and chat spans
139
+ agent_sync_span = next((s for s in spans if "agent_run_sync" in s["span_attributes"]["name"]), None)
140
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
141
+
142
+ assert agent_sync_span is not None, "agent_run_sync span not found"
143
+ assert chat_span is not None, "chat span not found"
144
+
145
+ # Check agent span
146
+ assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
147
+ assert agent_sync_span["metadata"]["model"] == "gpt-4o-mini"
148
+ assert agent_sync_span["metadata"]["provider"] == "openai"
149
+ assert TEST_PROMPT in str(agent_sync_span["input"])
150
+ assert "4" in str(agent_sync_span["output"])
151
+ _assert_metrics_are_valid(agent_sync_span["metrics"], start, end)
152
+
153
+ # Check chat span is a descendant of agent_run_sync span
154
+ # Build span tree to verify nesting
155
+ span_by_id = {s["span_id"]: s for s in spans}
156
+
157
+ def is_descendant(child_span, ancestor_id):
158
+ """Check if child_span is a descendant of ancestor_id."""
159
+ if not child_span.get("span_parents"):
160
+ return False
161
+ if ancestor_id in child_span["span_parents"]:
162
+ return True
163
+ # Check if any parent is a descendant
164
+                 for parent_id in child_span["span_parents"]:
165
+ if parent_id in span_by_id and is_descendant(span_by_id[parent_id], ancestor_id):
166
+ return True
167
+ return False
168
+
169
+
170
+ assert is_descendant(chat_span, agent_sync_span["span_id"]), "chat span should be nested under agent_run_sync"
171
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
172
+ assert chat_span["metadata"]["provider"] == "openai"
173
+ _assert_metrics_are_valid(chat_span["metrics"], start, end)
174
+
175
+ # Agent spans should have token metrics
176
+ assert "prompt_tokens" in agent_sync_span["metrics"]
177
+ assert "completion_tokens" in agent_sync_span["metrics"]
178
+
179
+
180
+ @pytest.mark.vcr
181
+ @pytest.mark.asyncio
182
+ async def test_multiple_identical_sequential_streams(memory_logger):
183
+ """Test multiple identical sequential streaming calls to ensure offsets don't accumulate.
184
+
185
+ This test makes 3 identical streaming calls in sequence. If timing is captured correctly,
186
+ each chat span's offset relative to its parent agent span should be roughly the same
187
+ (typically < 100ms). If offsets are accumulating incorrectly, we'd see the second and
188
+ third chat spans having much larger offsets than the first.
189
+ """
190
+ assert not memory_logger.pop()
191
+
192
+ @traced
193
+ async def run_multiple_identical_streams():
194
+ # Make 3 identical streaming calls
195
+ for i in range(3):
196
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
197
+ async with agent.run_stream("Count from 1 to 3.") as result:
198
+ full_text = ""
199
+ async for text in result.stream_text(delta=True):
200
+ full_text += text
201
+ print(f"Completed stream {i+1}")
202
+
203
+ await run_multiple_identical_streams()
204
+
205
+ # Check spans
206
+ spans = memory_logger.pop()
207
+
208
+ # Find agent and chat spans
209
+ agent_spans = [s for s in spans if "agent_run" in s["span_attributes"]["name"]]
210
+ chat_spans = [s for s in spans if "chat" in s["span_attributes"]["name"]]
211
+
212
+ assert len(agent_spans) >= 3, f"Expected at least 3 agent spans, got {len(agent_spans)}"
213
+ assert len(chat_spans) >= 3, f"Expected at least 3 chat spans, got {len(chat_spans)}"
214
+
215
+ # Sort by creation time
216
+ agent_spans.sort(key=lambda s: s["created"])
217
+ chat_spans.sort(key=lambda s: s["created"])
218
+
219
+ # Calculate time-to-first-token for each pair
220
+ time_to_first_tokens = []
221
+ for i in range(3):
222
+ agent_start = agent_spans[i]["metrics"]["start"]
223
+ chat_start = chat_spans[i]["metrics"]["start"]
224
+ ttft = chat_start - agent_start
225
+ time_to_first_tokens.append(ttft)
226
+
227
+ print(f"\n=== STREAM {i+1} ===")
228
+ print(f"Agent span start: {agent_start}")
229
+ print(f"Chat span start: {chat_start}")
230
+ print(f"Time to first token: {ttft}s")
231
+ print(f"Agent span ID: {agent_spans[i]['span_id']}")
232
+ print(f"Chat span parents: {chat_spans[i]['span_parents']}")
233
+
234
+ # CRITICAL: All three time-to-first-token values should be similar (within 0.5s of each other)
235
+ # If they're accumulating, the second and third would be much larger
236
+ min_ttft = min(time_to_first_tokens)
237
+ max_ttft = max(time_to_first_tokens)
238
+ ttft_spread = max_ttft - min_ttft
239
+
240
+ print(f"\n=== TIME-TO-FIRST-TOKEN ANALYSIS ===")
241
+ print(f"TTFT 1: {time_to_first_tokens[0]:.4f}s")
242
+ print(f"TTFT 2: {time_to_first_tokens[1]:.4f}s")
243
+ print(f"TTFT 3: {time_to_first_tokens[2]:.4f}s")
244
+ print(f"Min: {min_ttft:.4f}s, Max: {max_ttft:.4f}s, Spread: {ttft_spread:.4f}s")
245
+
246
+ # All should be small (< 3s)
247
+ for i, ttft in enumerate(time_to_first_tokens):
248
+ assert ttft < 3.0, f"Stream {i+1} time to first token too large: {ttft}s"
249
+
250
+ # Spread should be small (< 0.5s) - this catches the accumulation bug
251
+ assert ttft_spread < 0.5, f"Time-to-first-token spread too large: {ttft_spread}s - suggests timing is accumulating from previous calls"
252
+
253
+
254
+ @pytest.mark.vcr
255
+ @pytest.mark.asyncio
256
+ async def test_multiple_sequential_streams(memory_logger):
257
+ """Test multiple sequential streaming calls to ensure offsets don't accumulate."""
258
+ assert not memory_logger.pop()
259
+
260
+ @traced
261
+ async def run_multiple_streams():
262
+ agent1 = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
263
+ agent2 = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
264
+
265
+ start = time.time()
266
+
267
+ # First stream
268
+ async with agent1.run_stream("Count from 1 to 3.") as result1:
269
+ full_text1 = ""
270
+ async for text in result1.stream_text(delta=True):
271
+ full_text1 += text
272
+
273
+ # Second stream
274
+ async with agent2.run_stream("Count from 1 to 3.") as result2:
275
+ full_text2 = ""
276
+ async for text in result2.stream_text(delta=True):
277
+ full_text2 += text
278
+
279
+ return start
280
+
281
+ start = await run_multiple_streams()
282
+ end = time.time()
283
+
284
+ # Check spans
285
+ spans = memory_logger.pop()
286
+
287
+ # Should have: 1 parent (run_multiple_streams) + 2 agent_run_stream spans + 2 chat spans = 5 total
288
+ assert len(spans) >= 5, f"Expected at least 5 spans (1 parent + 2 agent_run_stream + 2 chat), got {len(spans)}"
289
+
290
+ # Find agent and chat spans
291
+ agent_spans = [s for s in spans if "agent_run" in s["span_attributes"]["name"]]
292
+ chat_spans = [s for s in spans if "chat" in s["span_attributes"]["name"]]
293
+
294
+ assert len(agent_spans) >= 2, f"Expected at least 2 agent spans, got {len(agent_spans)}"
295
+ assert len(chat_spans) >= 2, f"Expected at least 2 chat spans, got {len(chat_spans)}"
296
+
297
+ # Sort by creation time
298
+ agent_spans.sort(key=lambda s: s["created"])
299
+ chat_spans.sort(key=lambda s: s["created"])
300
+
301
+ agent1_span = agent_spans[0]
302
+ agent2_span = agent_spans[1]
303
+ chat1_span = chat_spans[0]
304
+ chat2_span = chat_spans[1]
305
+
306
+ # Check timing for first pair
307
+ agent1_start = agent1_span["metrics"]["start"]
308
+ chat1_start = chat1_span["metrics"]["start"]
309
+ time_to_first_token_1 = chat1_start - agent1_start
310
+
311
+ # Check timing for second pair
312
+ agent2_start = agent2_span["metrics"]["start"]
313
+ chat2_start = chat2_span["metrics"]["start"]
314
+ time_to_first_token_2 = chat2_start - agent2_start
315
+
316
+ print(f"\n=== FIRST STREAM ===")
317
+ print(f"Agent1 start: {agent1_start}")
318
+ print(f"Chat1 start: {chat1_start}")
319
+ print(f"Time to first token 1: {time_to_first_token_1}s")
320
+
321
+ print(f"\n=== SECOND STREAM ===")
322
+ print(f"Agent2 start: {agent2_start}")
323
+ print(f"Chat2 start: {chat2_start}")
324
+ print(f"Time to first token 2: {time_to_first_token_2}s")
325
+
326
+ print(f"\n=== RELATIVE TIMING ===")
327
+ print(f"Agent2 start - Agent1 start: {agent2_start - agent1_start}s")
328
+ print(f"Chat2 start - Chat1 start: {chat2_start - chat1_start}s")
329
+
330
+ # CRITICAL: Both time-to-first-token values should be small and similar
331
+ assert time_to_first_token_1 < 3.0, f"First time to first token too large: {time_to_first_token_1}s"
332
+ assert time_to_first_token_2 < 3.0, f"Second time to first token too large: {time_to_first_token_2}s - suggests start_time is being reused from first call"
333
+
334
+ # Agent2 should start AFTER agent1 finishes (or near the end)
335
+ agent1_end = agent1_span["metrics"]["end"]
336
+ assert agent2_start >= agent1_end - 0.1, f"Agent2 started too early: {agent2_start} vs Agent1 end: {agent1_end}"
337
+
338
+
339
+ @pytest.mark.vcr
340
+ @pytest.mark.asyncio
341
+ async def test_agent_run_stream(memory_logger):
342
+ """Test Agent.run_stream() streaming method."""
343
+ assert not memory_logger.pop()
344
+
345
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
346
+
347
+ start = time.time()
348
+ full_text = ""
349
+ async with agent.run_stream("Count from 1 to 5") as result:
350
+ async for text in result.stream_text(delta=True):
351
+ full_text += text
352
+ end = time.time()
353
+
354
+ # Verify we got streaming content
355
+ assert full_text
356
+ assert any(str(i) in full_text for i in range(1, 6))
357
+
358
+ # Check spans - should now have parent agent_run_stream + nested chat span
359
+ spans = memory_logger.pop()
360
+ assert len(spans) == 2, f"Expected 2 spans (agent_run_stream + chat), got {len(spans)}"
361
+
362
+ # Find agent_run_stream and chat spans
363
+ agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
364
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
365
+
366
+ assert agent_span is not None, "agent_run_stream span not found"
367
+ assert chat_span is not None, "chat span not found"
368
+
369
+ # Check agent span
370
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
371
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
372
+ assert "Count from 1 to 5" in str(agent_span["input"])
373
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
374
+
375
+ # Check chat span is nested under agent span
376
+ assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run_stream"
377
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
378
+ assert chat_span["metadata"]["provider"] == "openai"
379
+ _assert_metrics_are_valid(chat_span["metrics"], start, end)
380
+
381
+ # CRITICAL: Check that the chat span's start time is close to agent span start
382
+ # The offset/time-to-first-token should be small (typically < 2 seconds)
383
+ agent_start = agent_span["metrics"]["start"]
384
+ chat_start = chat_span["metrics"]["start"]
385
+ time_to_first_token = chat_start - agent_start
386
+
387
+ # Debug: Print full span data
388
+ print(f"\n=== AGENT SPAN ===")
389
+ print(f"ID: {agent_span['id']}")
390
+ print(f"span_id: {agent_span['span_id']}")
391
+ print(f"metrics: {agent_span['metrics']}")
392
+ print(f"\n=== CHAT SPAN ===")
393
+ print(f"ID: {chat_span['id']}")
394
+ print(f"span_id: {chat_span['span_id']}")
395
+ print(f"span_parents: {chat_span['span_parents']}")
396
+ print(f"metrics: {chat_span['metrics']}")
397
+
398
+ # Time to first token should be reasonable (< 3 seconds for API call initiation)
399
+ assert time_to_first_token < 3.0, f"Time to first token too large: {time_to_first_token}s - suggests start_time is being reused incorrectly"
400
+
401
+ # Both spans should have started during our test timeframe
402
+ assert agent_start >= start, "Agent span started before test"
403
+ assert chat_start >= start, "Chat span started before test"
404
+
405
+ # Agent spans should have token metrics
406
+ assert "prompt_tokens" in agent_span["metrics"]
407
+ assert "completion_tokens" in agent_span["metrics"]
408
+
409
+
410
+ @pytest.mark.vcr
411
+ @pytest.mark.asyncio
412
+ async def test_agent_with_tools(memory_logger):
413
+ """Test Agent with tool calls."""
414
+ assert not memory_logger.pop()
415
+
416
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=200))
417
+
418
+ @agent.tool_plain
419
+ def get_weather(city: str) -> str:
420
+ """Get weather for a city.
421
+
422
+ Args:
423
+ city: The city name
424
+ """
425
+ return f"It's sunny in {city}"
426
+
427
+ start = time.time()
428
+ result = await agent.run("What's the weather in Paris?")
429
+ end = time.time()
430
+
431
+ # Verify tool was used
432
+ assert result.output
433
+ assert "Paris" in str(result.output) or "sunny" in str(result.output)
434
+
435
+ # Check spans
436
+ spans = memory_logger.pop()
437
+ assert len(spans) >= 1 # At least the agent span, possibly more
438
+
439
+ # Find the agent span
440
+ agent_span = next(s for s in spans if "agent_run" in s["span_attributes"]["name"])
441
+ assert agent_span
442
+ assert "weather" in str(agent_span["input"]).lower() or "paris" in str(agent_span["input"]).lower()
443
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
444
+
445
+
446
+ @pytest.mark.vcr
447
+ @pytest.mark.asyncio
448
+ async def test_direct_model_request(memory_logger, direct):
449
+ """Test direct API model_request()."""
450
+ assert not memory_logger.pop()
451
+
452
+ messages = [ModelRequest(parts=[UserPromptPart(content=TEST_PROMPT)])]
453
+
454
+ start = time.time()
455
+ response = await direct.model_request(model=MODEL, messages=messages)
456
+ end = time.time()
457
+
458
+ # Verify response
459
+ assert response.parts
460
+ assert "4" in str(response.parts[0].content)
461
+
462
+ # Check spans
463
+ spans = memory_logger.pop()
464
+ # Direct API calls may create 1 or 2 spans depending on model wrapping
465
+ assert len(spans) >= 1
466
+
467
+ # Find the direct API span
468
+ direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
469
+ assert direct_span is not None
470
+
471
+ assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
472
+ assert direct_span["metadata"]["model"] == "gpt-4o-mini"
473
+ assert direct_span["metadata"]["provider"] == "openai"
474
+ assert TEST_PROMPT in str(direct_span["input"])
475
+ assert "4" in str(direct_span["output"])
476
+ _assert_metrics_are_valid(direct_span["metrics"], start, end)
477
+
478
+
479
+ @pytest.mark.vcr
480
+ def test_direct_model_request_sync(memory_logger, direct):
481
+ """Test direct API model_request_sync()."""
482
+ assert not memory_logger.pop()
483
+
484
+ messages = [ModelRequest(parts=[UserPromptPart(content=TEST_PROMPT)])]
485
+
486
+ start = time.time()
487
+ response = direct.model_request_sync(model=MODEL, messages=messages)
488
+ end = time.time()
489
+
490
+ # Verify response
491
+ assert response.parts
492
+ assert "4" in str(response.parts[0].content)
493
+
494
+ # Check spans - direct API may create 2-3 spans depending on wrapping layers
495
+ spans = memory_logger.pop()
496
+ assert len(spans) >= 2
497
+
498
+ # Find the model_request_sync span
499
+ span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_sync"), None)
500
+ assert span is not None, "model_request_sync span not found"
501
+ assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
502
+ assert span["metadata"]["model"] == "gpt-4o-mini"
503
+ assert TEST_PROMPT in str(span["input"])
504
+ _assert_metrics_are_valid(span["metrics"], start, end)
505
+
506
+
507
+ @pytest.mark.vcr
508
+ @pytest.mark.asyncio
509
+ async def test_direct_model_request_with_settings(memory_logger, direct):
510
+ """Test that model_settings appears in input for direct API calls."""
511
+ assert not memory_logger.pop()
512
+
513
+ messages = [ModelRequest(parts=[UserPromptPart(content="Say hello")])]
514
+ custom_settings = ModelSettings(max_tokens=50, temperature=0.7)
515
+
516
+ start = time.time()
517
+ result = await direct.model_request(model=MODEL, messages=messages, model_settings=custom_settings)
518
+ end = time.time()
519
+
520
+ # Verify result
521
+ assert result.parts
522
+
523
+ # Check spans
524
+ spans = memory_logger.pop()
525
+ # Direct API calls may create 1 or 2 spans depending on model wrapping
526
+ assert len(spans) >= 1
527
+
528
+ # Find the direct API span
529
+ direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
530
+ assert direct_span is not None
531
+
532
+ assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
533
+
534
+ # Verify model_settings is in input (NOT metadata)
535
+ assert "model_settings" in direct_span["input"], "model_settings should be in input"
536
+ settings = direct_span["input"]["model_settings"]
537
+ assert settings["max_tokens"] == 50
538
+ assert settings["temperature"] == 0.7
539
+
540
+ # Verify model_settings is NOT in metadata
541
+ assert "model_settings" not in direct_span["metadata"], "model_settings should NOT be in metadata"
542
+
543
+ # Verify metadata still has model and provider
544
+ assert direct_span["metadata"]["model"] == "gpt-4o-mini"
545
+ assert direct_span["metadata"]["provider"] == "openai"
546
+
547
+ _assert_metrics_are_valid(direct_span["metrics"], start, end)
548
+
549
+
550
+ @pytest.mark.vcr
551
+ @pytest.mark.asyncio
552
+ async def test_direct_model_request_stream(memory_logger, direct):
553
+ """Test direct API model_request_stream()."""
554
+ assert not memory_logger.pop()
555
+
556
+ messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 3")])]
557
+
558
+ start = time.time()
559
+ chunk_count = 0
560
+ async with direct.model_request_stream(model=MODEL, messages=messages) as stream:
561
+ async for chunk in stream:
562
+ chunk_count += 1
563
+ end = time.time()
564
+
565
+ # Verify we got chunks
566
+ assert chunk_count > 0
567
+
568
+ # Check spans
569
+ spans = memory_logger.pop()
570
+ # Direct API calls may create 1 or 2 spans depending on model wrapping
571
+ assert len(spans) >= 1
572
+
573
+ # Find the direct API span
574
+ direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_stream"), None)
575
+ assert direct_span is not None
576
+
577
+ assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
578
+ assert direct_span["metadata"]["model"] == "gpt-4o-mini"
579
+ _assert_metrics_are_valid(direct_span["metrics"], start, end)
580
+
581
+
582
+ @pytest.mark.vcr
583
+ @pytest.mark.asyncio
584
+ async def test_direct_model_request_stream_complete_output(memory_logger, direct):
585
+ """Test that direct API streaming captures all text including first chunk from PartStartEvent."""
586
+ assert not memory_logger.pop()
587
+
588
+ messages = [ModelRequest(parts=[UserPromptPart(content="Say exactly: 1, 2, 3")])]
589
+
590
+ collected_text = ""
591
+ seen_delta = False
592
+ async with direct.model_request_stream(model=MODEL, messages=messages) as stream:
593
+ async for chunk in stream:
594
+ # Extract text, skipping final PartStartEvent after deltas
595
+ if hasattr(chunk, 'part') and hasattr(chunk.part, 'content') and not seen_delta:
596
+ # PartStartEvent has part.content with initial text
597
+ collected_text += str(chunk.part.content)
598
+ elif hasattr(chunk, 'delta') and chunk.delta:
599
+ seen_delta = True
600
+ # PartDeltaEvent has delta.content_delta
601
+ if hasattr(chunk.delta, 'content_delta') and chunk.delta.content_delta:
602
+ collected_text += chunk.delta.content_delta
603
+
604
+ # Verify we got complete output including "1"
605
+ assert "1" in collected_text
606
+ assert "2" in collected_text
607
+ assert "3" in collected_text
608
+
609
+ # Check spans were created
610
+ spans = memory_logger.pop()
611
+ assert len(spans) >= 1
612
+
613
+
614
+ @pytest.mark.vcr
615
+ @pytest.mark.asyncio
616
+ async def test_direct_api_streaming_call_3(memory_logger, direct):
617
+ """Test direct API streaming (call 3) - should output complete '1, 2, 3, 4, 5'."""
618
+ assert not memory_logger.pop()
619
+
620
+ IDENTICAL_PROMPT = "Count from 1 to 5."
621
+ messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]
622
+
623
+ collected_text = ""
624
+ async with direct.model_request_stream(model="openai:gpt-4o", messages=messages, model_settings=ModelSettings(max_tokens=100)) as stream:
625
+ async for chunk in stream:
626
+ # FIX: Handle PartStartEvent which contains initial text
627
+ if hasattr(chunk, 'part') and hasattr(chunk.part, 'content'):
628
+ collected_text += str(chunk.part.content)
629
+ # Handle PartDeltaEvent with delta content
630
+ elif hasattr(chunk, 'delta') and chunk.delta:
631
+ if hasattr(chunk.delta, 'content_delta') and chunk.delta.content_delta:
632
+ collected_text += chunk.delta.content_delta
633
+
634
+ # Now this should pass!
635
+ assert "1" in collected_text, f"Expected '1' in output but got: {collected_text}"
636
+ assert "2" in collected_text
637
+ assert "3" in collected_text
638
+ assert "4" in collected_text
639
+ assert "5" in collected_text
640
+
641
+
642
+ @pytest.mark.vcr
643
+ @pytest.mark.asyncio
644
+ async def test_direct_api_streaming_call_4(memory_logger, direct):
645
+ """Test direct API streaming (call 4) - identical to call 3."""
646
+ assert not memory_logger.pop()
647
+
648
+ IDENTICAL_PROMPT = "Count from 1 to 5."
649
+ messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]
650
+
651
+ collected_text = ""
652
+ async with direct.model_request_stream(model="openai:gpt-4o", messages=messages, model_settings=ModelSettings(max_tokens=100)) as stream:
653
+ async for chunk in stream:
654
+ # FIX: Handle PartStartEvent which contains initial text
655
+ if hasattr(chunk, 'part') and hasattr(chunk.part, 'content'):
656
+ collected_text += str(chunk.part.content)
657
+ # Handle PartDeltaEvent with delta content
658
+ elif hasattr(chunk, 'delta') and chunk.delta:
659
+ if hasattr(chunk.delta, 'content_delta') and chunk.delta.content_delta:
660
+ collected_text += chunk.delta.content_delta
661
+
662
+ # Now this should pass!
663
+ assert "1" in collected_text, f"Expected '1' in output but got: {collected_text}"
664
+
665
+
666
+ @pytest.mark.vcr
667
+ @pytest.mark.asyncio
668
+ async def test_direct_api_streaming_early_break_call_5(memory_logger, direct):
669
+ """Test direct API streaming with early break (call 5) - should still get first few chars including '1'."""
670
+ assert not memory_logger.pop()
671
+
672
+ IDENTICAL_PROMPT = "Count from 1 to 5."
673
+ messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]
674
+
675
+ collected_text = ""
676
+ i = 0
677
+ async with direct.model_request_stream(model="openai:gpt-4o", messages=messages, model_settings=ModelSettings(max_tokens=100)) as stream:
678
+ async for chunk in stream:
679
+ # FIX: Handle PartStartEvent which contains initial text
680
+ if hasattr(chunk, 'part') and hasattr(chunk.part, 'content'):
681
+ collected_text += str(chunk.part.content)
682
+ # Handle PartDeltaEvent with delta content
683
+ elif hasattr(chunk, 'delta') and chunk.delta:
684
+ if hasattr(chunk.delta, 'content_delta') and chunk.delta.content_delta:
685
+ collected_text += chunk.delta.content_delta
686
+
687
+ i += 1
688
+ if i >= 3:
689
+ break
690
+
691
+ # Even with early break after 3 chunks, we should capture text from PartStartEvent (chunk 1)
692
+ print(f"Collected text: '{collected_text}'")
693
+ assert len(collected_text) > 0, f"Expected some text even with early break but got empty string"
694
+ # Verify we're capturing PartStartEvent by checking we got text before breaking at chunk 3
695
+ assert collected_text, f"Should have captured text from PartStartEvent or first delta"
696
+
697
+
698
+ @pytest.mark.vcr
699
+ @pytest.mark.asyncio
700
+ async def test_direct_api_streaming_no_duplication(memory_logger, direct):
701
+ """Test that direct API streaming doesn't duplicate output and captures all text in span."""
702
+ assert not memory_logger.pop()
703
+
704
+ collected_text = ""
705
+ chunk_count = 0
706
+
707
+ # Use direct API streaming
708
+ messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 5, separated by commas.")])]
709
+ async with direct.model_request_stream(
710
+ messages=messages,
711
+ model_settings=ModelSettings(max_tokens=100),
712
+ model="openai:gpt-4o",
713
+ ) as response:
714
+ async for chunk in response:
715
+ chunk_count += 1
716
+ # Extract text from chunk
717
+ text = None
718
+ if hasattr(chunk, 'part') and hasattr(chunk.part, 'content'):
719
+ text = str(chunk.part.content)
720
+ elif hasattr(chunk, 'delta') and chunk.delta:
721
+ if hasattr(chunk.delta, 'content_delta') and chunk.delta.content_delta:
722
+ text = chunk.delta.content_delta
723
+
724
+ if text:
725
+ collected_text += text
726
+
727
+ print(f"Collected text from stream: '{collected_text}'")
728
+ print(f"Total chunks: {chunk_count}")
729
+
730
+ # Verify we collected complete text
731
+ assert len(collected_text) > 0, "Should have collected text from stream"
732
+ assert "1" in collected_text, "Should have '1' in output"
733
+
734
+ # Check span captured the full output
735
+ spans = memory_logger.pop()
736
+ assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"
737
+
738
+ # Find the model_request_stream span
739
+ stream_span = next((s for s in spans if "model_request_stream" in s["span_attributes"]["name"]), None)
740
+ assert stream_span is not None, "model_request_stream span not found"
741
+
742
+ # Check that span output contains the full text, not just "1,"
743
+ span_output = stream_span.get("output", {})
744
+ print(f"Span output: {span_output}")
745
+
746
+ # The span should capture the full response
747
+ if "response" in span_output and "parts" in span_output["response"]:
748
+ parts = span_output["response"]["parts"]
749
+ span_text = "".join(str(p.get("content", "")) for p in parts if isinstance(p, dict))
750
+ print(f"Span captured text: '{span_text}'")
751
+ # Should have more than just "1,"
752
+ assert len(span_text) > 2, f"Span should capture more than just '1,', got: '{span_text}'"
753
+ assert "1" in span_text, "Span should contain '1'"
754
+
755
+
756
+ @pytest.mark.vcr
757
+ @pytest.mark.asyncio
758
+ async def test_direct_api_streaming_no_duplication_comprehensive(memory_logger, direct):
759
+ """Comprehensive test matching golden test setup to verify no duplication and full output capture."""
760
+ assert not memory_logger.pop()
761
+
762
+ # Match golden test exactly
763
+ IDENTICAL_PROMPT = "Count from 1 to 5."
764
+ IDENTICAL_SETTINGS = ModelSettings(max_tokens=100)
765
+
766
+ messages = [ModelRequest(parts=[UserPromptPart(content=IDENTICAL_PROMPT)])]
767
+
768
+ collected_text = ""
769
+ chunk_types = []
770
+ seen_delta = False
771
+
772
+ async with direct.model_request_stream(messages=messages, model_settings=IDENTICAL_SETTINGS, model="openai:gpt-4o") as stream:
773
+ async for chunk in stream:
774
+ # Track chunk types
775
+ if hasattr(chunk, 'part') and hasattr(chunk.part, 'content') and not seen_delta:
776
+ chunk_types.append(('PartStartEvent', str(chunk.part.content)))
777
+ text = str(chunk.part.content)
778
+ collected_text += text
779
+ elif hasattr(chunk, 'delta') and chunk.delta:
780
+ seen_delta = True
781
+ if hasattr(chunk.delta, 'content_delta') and chunk.delta.content_delta:
782
+ chunk_types.append(('PartDeltaEvent', chunk.delta.content_delta))
783
+ text = chunk.delta.content_delta
784
+ collected_text += text
785
+
786
+ print(f"\nCollected text: '{collected_text}'")
787
+ print(f"Total chunks received: {len(chunk_types)}")
788
+ print(f"All chunk types:")
789
+ for i, (chunk_type, content) in enumerate(chunk_types):
790
+ print(f" {i}: {chunk_type} = {content!r}")
791
+
792
+ # Verify no duplication in collected text
793
+ # Expected: "Sure! Here you go:\n\n1, 2, 3, 4, 5." or similar (length ~30)
794
+ # Should NOT be duplicated
795
+ assert len(collected_text) < 60, f"Text seems duplicated (too long): '{collected_text}' (len={len(collected_text)})"
796
+ assert collected_text.count("1, 2, 3") == 1, f"Text should appear once, not duplicated: '{collected_text}'"
797
+
798
+ # Check span
799
+ spans = memory_logger.pop()
800
+ print(f"Number of spans: {len(spans)}")
801
+ for i, s in enumerate(spans):
802
+ print(f"Span {i}: {s['span_attributes']['name']} (type: {s['span_attributes'].get('type', 'N/A')})")
803
+ if 'span_parents' in s and s['span_parents']:
804
+ print(f" Parents: {s['span_parents']}")
805
+
806
+ # Should have 1 or 2 spans (direct API wrapper + potentially model wrapper)
807
+ assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"
808
+
809
+ # Find the model_request_stream span
810
+ stream_span = next((s for s in spans if "model_request_stream" in s["span_attributes"]["name"]), None)
811
+ assert stream_span is not None, "model_request_stream span not found"
812
+
813
+ # Check that span output is not empty and captures reasonable amount of text
814
+ span_output = stream_span.get("output", {})
815
+ print(f"Span output keys: {span_output.keys() if span_output else 'None'}")
816
+
817
+ if "parts" in span_output:
818
+ parts = span_output.get("parts", [])
819
+ print(f"Span parts: {parts}")
820
+ if parts and len(parts) > 0:
821
+ first_part = parts[0]
822
+ print(f"First part type: {type(first_part)}")
823
+ print(f"First part: {first_part}")
824
+ if isinstance(first_part, dict):
825
+ part_content = first_part.get("content", "")
826
+ print(f"Part content: '{part_content}'")
827
+ print(f"Part content length: {len(part_content)}")
828
+ # The span should capture the FULL text, not just "1,"
829
+ assert len(part_content) > 5, f"Span should capture full text, got: '{part_content}'"
830
+
831
+
832
+ @pytest.mark.vcr
833
+ @pytest.mark.asyncio
834
+ async def test_async_generator_pattern_call_6(memory_logger):
835
+ """Test async generator pattern (call 6) - wrapping stream in async generator."""
836
+ assert not memory_logger.pop()
837
+
838
+ IDENTICAL_PROMPT = "Count from 1 to 5."
839
+
840
+ async def stream_with_async_generator(prompt: str):
841
+ """Wrap the stream in an async generator (customer pattern)."""
842
+ agent = Agent("openai:gpt-4o", model_settings=ModelSettings(max_tokens=100))
843
+ async for event in agent.run_stream_events(prompt):
844
+ yield event
845
+
846
+ collected_text = ""
847
+ i = 0
848
+ async for event in stream_with_async_generator(IDENTICAL_PROMPT):
849
+ # run_stream_events returns ResultEvent objects with different structure
850
+ # Try to extract text from whatever event type we get
851
+ if hasattr(event, 'content') and event.content:
852
+ collected_text += str(event.content)
853
+ elif hasattr(event, 'part') and hasattr(event.part, 'content'):
854
+ collected_text += str(event.part.content)
855
+ elif hasattr(event, 'delta') and event.delta:
856
+ if hasattr(event.delta, 'content_delta') and event.delta.content_delta:
857
+ collected_text += event.delta.content_delta
858
+
859
+ i += 1
860
+ if i >= 3:
861
+ break
862
+
863
+ # This should capture something
864
+ print(f"Collected text from generator: '{collected_text}'")
865
+ assert len(collected_text) > 0, f"Expected some text from async generator but got empty string"
866
+
867
+
868
+ @pytest.mark.vcr
869
+ @pytest.mark.asyncio
870
+ async def test_agent_structured_output(memory_logger):
871
+ """Test Agent with structured output (Pydantic model)."""
872
+ assert not memory_logger.pop()
873
+
874
+ class MathAnswer(BaseModel):
875
+ answer: int
876
+ explanation: str
877
+
878
+ agent = Agent(
879
+ MODEL,
880
+ output_type=MathAnswer,
881
+ model_settings=ModelSettings(max_tokens=200)
882
+ )
883
+
884
+ start = time.time()
885
+ result = await agent.run("What is 10 + 15?")
886
+ end = time.time()
887
+
888
+ # Verify structured output
889
+ assert isinstance(result.output, MathAnswer)
890
+ assert result.output.answer == 25
891
+ assert result.output.explanation
892
+
893
+ # Check spans - should have parent agent_run + nested spans
894
+ spans = memory_logger.pop()
895
+ assert len(spans) >= 2, f"Expected at least 2 spans (agent_run + chat), got {len(spans)}"
896
+
897
+ # Find agent_run and chat spans
898
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
899
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
900
+
901
+ assert agent_span is not None, "agent_run span not found"
902
+ assert chat_span is not None, "chat span not found"
903
+
904
+ # Check agent span
905
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
906
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
907
+ assert agent_span["metadata"]["provider"] == "openai"
908
+ assert "10 + 15" in str(agent_span["input"])
909
+ assert "25" in str(agent_span["output"])
910
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
911
+
912
+ # Check chat span is a descendant of agent_run
913
+ span_by_id = {s["span_id"]: s for s in spans}
914
+
915
+ def is_descendant(child_span, ancestor_id):
916
+ """Check if child_span is a descendant of ancestor_id."""
917
+ if not child_span.get("span_parents"):
918
+ return False
919
+ if ancestor_id in child_span["span_parents"]:
920
+ return True
921
+ for parent_id in child_span["span_parents"]:
922
+ if parent_id in span_by_id and is_descendant(span_by_id[parent_id], ancestor_id):
923
+ return True
924
+ return False
925
+
926
+ assert is_descendant(chat_span, agent_span["span_id"]), "chat span should be nested under agent_run"
927
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
928
+ assert chat_span["metadata"]["provider"] == "openai"
929
+ _assert_metrics_are_valid(chat_span["metrics"], start, end)
930
+
931
+ # Agent spans should have token metrics
932
+ assert "prompt_tokens" in agent_span["metrics"]
933
+ assert "completion_tokens" in agent_span["metrics"]
934
+
935
+
936
+ @pytest.mark.vcr
937
+ @pytest.mark.asyncio
938
+ async def test_agent_with_model_settings_in_metadata(memory_logger):
939
+ """Test that model_settings from agent config appears in metadata, not input."""
940
+ assert not memory_logger.pop()
941
+
942
+ custom_settings = ModelSettings(max_tokens=100, temperature=0.5)
943
+ agent = Agent(MODEL, model_settings=custom_settings)
944
+
945
+ start = time.time()
946
+ result = await agent.run("Say hello")
947
+ end = time.time()
948
+
949
+ assert result.output
950
+
951
+ # Check spans
952
+ spans = memory_logger.pop()
953
+ assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"
954
+
955
+ # Find agent_run and chat spans
956
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
957
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
958
+
959
+ assert agent_span is not None, "agent_run span not found"
960
+ assert chat_span is not None, "chat span not found"
961
+
962
+ # Verify model_settings is in agent METADATA (not input, since it's agent config)
963
+ assert "model_settings" in agent_span["metadata"], "model_settings should be in agent_run metadata"
964
+ agent_settings = agent_span["metadata"]["model_settings"]
965
+ assert agent_settings["max_tokens"] == 100
966
+ assert agent_settings["temperature"] == 0.5
967
+
968
+ # Verify model_settings is NOT in agent input (it wasn't passed to run())
969
+ assert "model_settings" not in agent_span["input"], "model_settings should NOT be in agent_run input when not passed to run()"
970
+
971
+ # Verify model_settings is in chat input (passed to the model)
972
+ assert "model_settings" in chat_span["input"], "model_settings should be in chat span input"
973
+ chat_settings = chat_span["input"]["model_settings"]
974
+ assert chat_settings["max_tokens"] == 100
975
+ assert chat_settings["temperature"] == 0.5
976
+
977
+ # Verify model_settings is NOT in chat metadata (it's in input)
978
+ assert "model_settings" not in chat_span["metadata"], "model_settings should NOT be in chat span metadata"
979
+
980
+ # Verify other metadata is present
981
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
982
+ assert chat_span["metadata"]["provider"] == "openai"
983
+
984
+
985
+ @pytest.mark.vcr
986
+ @pytest.mark.asyncio
987
+ async def test_agent_with_model_settings_override_in_input(memory_logger):
988
+ """Test that model_settings passed to run() appears in input, not metadata."""
989
+ assert not memory_logger.pop()
990
+
991
+ # Agent has default settings
992
+ default_settings = ModelSettings(max_tokens=50)
993
+ agent = Agent(MODEL, model_settings=default_settings)
994
+
995
+ # Override with different settings in run() call
996
+ override_settings = ModelSettings(max_tokens=200, temperature=0.9)
997
+
998
+ start = time.time()
999
+ result = await agent.run("Tell me a story", model_settings=override_settings)
1000
+ end = time.time()
1001
+
1002
+ assert result.output
1003
+
1004
+ # Check spans
1005
+ spans = memory_logger.pop()
1006
+ assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"
1007
+
1008
+ # Find agent_run span
1009
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
1010
+ assert agent_span is not None, "agent_run span not found"
1011
+
1012
+ # Verify override settings are in agent INPUT (because they were passed to run())
1013
+ assert "model_settings" in agent_span["input"], "model_settings should be in agent_run input when passed to run()"
1014
+ input_settings = agent_span["input"]["model_settings"]
1015
+ assert input_settings["max_tokens"] == 200, "Should use override settings from run() call"
1016
+ assert input_settings["temperature"] == 0.9
1017
+
1018
+ # Verify agent default settings are NOT in metadata (when overridden in input, we don't duplicate in metadata)
1019
+ assert "model_settings" not in agent_span["metadata"], "model_settings should NOT be in metadata when explicitly passed to run()"
1020
+
1021
+
1022
+ @pytest.mark.vcr
1023
+ @pytest.mark.asyncio
1024
+ async def test_agent_with_system_prompt_in_metadata(memory_logger):
1025
+ """Test that system_prompt from agent config appears in input (it's semantically part of LLM input)."""
1026
+ assert not memory_logger.pop()
1027
+
1028
+ system_prompt = "You are a helpful AI assistant who speaks like a pirate."
1029
+ agent = Agent(MODEL, system_prompt=system_prompt, model_settings=ModelSettings(max_tokens=100))
1030
+
1031
+ start = time.time()
1032
+ result = await agent.run("What is the weather?")
1033
+ end = time.time()
1034
+
1035
+ assert result.output
1036
+
1037
+ # Check spans
1038
+ spans = memory_logger.pop()
1039
+ assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"
1040
+
1041
+ # Find agent_run span
1042
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
1043
+ assert agent_span is not None, "agent_run span not found"
1044
+
1045
+ # Verify system_prompt is in input (because it's semantically part of the LLM input)
1046
+ assert "system_prompt" in agent_span["input"], "system_prompt should be in agent_run input"
1047
+ assert agent_span["input"]["system_prompt"] == system_prompt, "system_prompt should be the actual string, not a method reference"
1048
+
1049
+ # Verify system_prompt is NOT in metadata
1050
+ assert "system_prompt" not in agent_span["metadata"], "system_prompt should NOT be in agent_run metadata"
1051
+
1052
+ # Verify other metadata is present
1053
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
1054
+ assert agent_span["metadata"]["provider"] == "openai"
1055
+
1056
+
1057
+ @pytest.mark.vcr
1058
+ @pytest.mark.asyncio
1059
+ async def test_agent_with_message_history(memory_logger):
1060
+ """Test Agent with conversation history."""
1061
+ assert not memory_logger.pop()
1062
+
1063
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
1064
+
1065
+ # First message
1066
+ result1 = await agent.run("My name is Alice")
1067
+ assert result1.output
1068
+ memory_logger.pop() # Clear first span
1069
+
1070
+ # Second message with history
1071
+ start = time.time()
1072
+ result2 = await agent.run(
1073
+ "What is my name?",
1074
+ message_history=result1.all_messages()
1075
+ )
1076
+ end = time.time()
1077
+
1078
+ # Verify it remembers
1079
+ assert "Alice" in str(result2.output)
1080
+
1081
+ # Check spans - should now have parent agent_run + nested chat span
1082
+ spans = memory_logger.pop()
1083
+ assert len(spans) == 2, f"Expected 2 spans (agent_run + chat), got {len(spans)}"
1084
+
1085
+ # Find agent_run and chat spans
1086
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
1087
+
1088
+ assert agent_span is not None, "agent_run span not found"
1089
+ assert "message_history" in str(agent_span["input"])
1090
+ assert "Alice" in str(agent_span["output"])
1091
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
1092
+
1093
+
1094
+ @pytest.mark.vcr
1095
+ @pytest.mark.asyncio
1096
+ async def test_agent_with_custom_settings(memory_logger):
1097
+ """Test Agent with custom model settings."""
1098
+ assert not memory_logger.pop()
1099
+
1100
+ agent = Agent(MODEL)
1101
+
1102
+ start = time.time()
1103
+ result = await agent.run(
1104
+ "Say hello",
1105
+ model_settings=ModelSettings(
1106
+ max_tokens=20,
1107
+ temperature=0.5,
1108
+ top_p=0.9
1109
+ )
1110
+ )
1111
+ end = time.time()
1112
+
1113
+ assert result.output
1114
+
1115
+ # Check spans - should now have parent agent_run + nested chat span
1116
+ spans = memory_logger.pop()
1117
+ assert len(spans) >= 2, f"Expected at least 2 spans (agent_run + chat), got {len(spans)}"
1118
+
1119
+ # Find agent_run span
1120
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
1121
+ assert agent_span is not None, "agent_run span not found"
1122
+
1123
+ # Model settings passed to run() should be in input (not metadata)
1124
+ assert "model_settings" in agent_span["input"]
1125
+ settings = agent_span["input"]["model_settings"]
1126
+ assert settings["max_tokens"] == 20
1127
+ assert settings["temperature"] == 0.5
1128
+ assert settings["top_p"] == 0.9
1129
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
1130
+
1131
+
1132
+ @pytest.mark.vcr
1133
+ def test_agent_run_stream_sync(memory_logger):
1134
+ """Test Agent.run_stream_sync() synchronous streaming method."""
1135
+ assert not memory_logger.pop()
1136
+
1137
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
1138
+
1139
+ start = time.time()
1140
+ full_text = ""
1141
+ result = agent.run_stream_sync("Count from 1 to 3")
1142
+ for text in result.stream_text(delta=True):
1143
+ full_text += text
1144
+ end = time.time()
1145
+
1146
+ # Verify we got streaming content
1147
+ assert full_text
1148
+ assert any(str(i) in full_text for i in range(1, 4))
1149
+
1150
+ # Check spans - should have parent agent_run_stream_sync + nested spans
1151
+ spans = memory_logger.pop()
1152
+ assert len(spans) >= 2, f"Expected at least 2 spans (agent_run_stream_sync + chat), got {len(spans)}"
1153
+
1154
+ # Find agent_run_stream_sync and chat spans
1155
+ agent_span = next((s for s in spans if "agent_run_stream_sync" in s["span_attributes"]["name"]), None)
1156
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
1157
+
1158
+ assert agent_span is not None, "agent_run_stream_sync span not found"
1159
+ assert chat_span is not None, "chat span not found"
1160
+
1161
+ # Check agent span
1162
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1163
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
1164
+ assert "Count from 1 to 3" in str(agent_span["input"])
1165
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
1166
+
1167
+ # Check chat span is a descendant of agent_run_stream_sync
1168
+ span_by_id = {s["span_id"]: s for s in spans}
1169
+
1170
+ def is_descendant(child_span, ancestor_id):
1171
+ """Check if child_span is a descendant of ancestor_id."""
1172
+ if not child_span.get("span_parents"):
1173
+ return False
1174
+ if ancestor_id in child_span["span_parents"]:
1175
+ return True
1176
+ for parent_id in child_span["span_parents"]:
1177
+ if parent_id in span_by_id and is_descendant(span_by_id[parent_id], ancestor_id):
1178
+ return True
1179
+ return False
1180
+
1181
+ assert is_descendant(chat_span, agent_span["span_id"]), "chat span should be nested under agent_run_stream_sync"
1182
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
1183
+ assert chat_span["metadata"]["provider"] == "openai"
1184
+ # Chat span may not have complete metrics since it's an intermediate span
1185
+ assert "start" in chat_span["metrics"]
1186
+
1187
+ # Agent spans should have token metrics
1188
+ assert "prompt_tokens" in agent_span["metrics"]
1189
+ assert "completion_tokens" in agent_span["metrics"]
1190
+
1191
+
1192
+ @pytest.mark.vcr
1193
+ @pytest.mark.asyncio
1194
+ async def test_agent_run_stream_events(memory_logger):
1195
+ """Test Agent.run_stream_events() event streaming method."""
1196
+ assert not memory_logger.pop()
1197
+
1198
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
1199
+
1200
+ start = time.time()
1201
+ event_count = 0
1202
+ events = []
1203
+ # Consume all events
1204
+ async for event in agent.run_stream_events("What is 5+5?"):
1205
+ event_count += 1
1206
+ events.append(event)
1207
+ end = time.time()
1208
+
1209
+ # Verify we got events
1210
+ assert event_count > 0, "Should receive at least one event"
1211
+
1212
+ # Check spans - should have agent_run_stream_events span
1213
+ spans = memory_logger.pop()
1214
+ assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"
1215
+
1216
+ # Find agent_run_stream_events span
1217
+ agent_span = next((s for s in spans if "agent_run_stream_events" in s["span_attributes"]["name"]), None)
1218
+ assert agent_span is not None, "agent_run_stream_events span not found"
1219
+
1220
+ # Check agent span has basic structure
1221
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1222
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
1223
+ assert "5+5" in str(agent_span["input"]) or "What" in str(agent_span["input"])
1224
+ assert agent_span["metrics"]["event_count"] == event_count
1225
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
1226
+
1227
+
1228
+ @pytest.mark.vcr
1229
+ def test_direct_model_request_stream_sync(memory_logger, direct):
1230
+ """Test direct API model_request_stream_sync()."""
1231
+ assert not memory_logger.pop()
1232
+
1233
+ messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 3")])]
1234
+
1235
+ start = time.time()
1236
+ chunk_count = 0
1237
+ with direct.model_request_stream_sync(model=MODEL, messages=messages) as stream:
1238
+ for chunk in stream:
1239
+ chunk_count += 1
1240
+ end = time.time()
1241
+
1242
+ # Verify we got chunks
1243
+ assert chunk_count > 0
1244
+
1245
+ # Check spans
1246
+ spans = memory_logger.pop()
1247
+ assert len(spans) == 1
1248
+
1249
+ span = spans[0]
1250
+ assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1251
+ assert span["span_attributes"]["name"] == "model_request_stream_sync"
1252
+ assert span["metadata"]["model"] == "gpt-4o-mini"
1253
+ _assert_metrics_are_valid(span["metrics"], start, end)
1254
+
1255
+
1256
+ @pytest.mark.vcr
1257
+ @pytest.mark.asyncio
1258
+ async def test_stream_early_break_async_generator(memory_logger, direct):
1259
+ """Test breaking early from an async generator wrapper around a stream.
1260
+
1261
+ This reproduces the 'Token was created in a different Context' error that occurs
1262
+ when breaking early from async generators. The cleanup happens in a different
1263
+ async context, causing ContextVar token errors.
1264
+
1265
+ Our fix: Clear the context token before cleanup to make unset_current() use
1266
+ the safe set(None) path instead of reset(token) (a minimal ContextVar sketch follows this test).
1267
+ """
1268
+ assert not memory_logger.pop()
1269
+
1270
+ messages = [ModelRequest(parts=[UserPromptPart(content="Count from 1 to 5")])]
1271
+
1272
+ async def stream_wrapper():
1273
+ """Wrap the stream in an async generator (common customer pattern)."""
1274
+ async with direct.model_request_stream(model=MODEL, messages=messages) as stream:
1275
+ count = 0
1276
+ async for chunk in stream:
1277
+ yield chunk
1278
+ count += 1
1279
+ if count >= 3:
1280
+ # Break early - this triggers generator cleanup in a different async context
1281
+ break
1282
+
1283
+ start = time.time()
1284
+ chunk_count = 0
1285
+
1286
+ # This should NOT raise ValueError about "different Context"
1287
+ async for chunk in stream_wrapper():
1288
+ chunk_count += 1
1289
+
1290
+ end = time.time()
1291
+
1292
+ # Should not raise ValueError about context token
1293
+ assert chunk_count == 3
1294
+
1295
+ # Check spans - should have created a span despite early break
1296
+ spans = memory_logger.pop()
1297
+ assert len(spans) >= 1, "Should have at least one span even with early break"
1298
+
1299
+ span = spans[0]
1300
+ assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1301
+ assert span["span_attributes"]["name"] == "model_request_stream"
1302
+
1303
+
1304
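A minimal sketch of the ContextVar pattern the docstring above describes, assuming a simple module-level current-span variable; the names (_current_span, set_current, unset_current) are illustrative and do not claim to match braintrust's actual internals:

    import contextvars

    _current_span = contextvars.ContextVar("current_span", default=None)

    def set_current(span):
        # set() returns a Token bound to the Context that called it.
        return _current_span.set(span)

    def unset_current(token=None):
        if token is not None:
            try:
                # reset() raises ValueError when generator cleanup runs in a
                # different async Context than the one that created the token.
                _current_span.reset(token)
                return
            except ValueError:
                pass
        # Safe fallback: clearing the value never depends on the token's Context.
        _current_span.set(None)

Clearing the stored token before cleanup forces the safe set(None) path, which is what the two early-break tests above exercise.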
+ @pytest.mark.vcr
1305
+ @pytest.mark.asyncio
1306
+ async def test_agent_stream_early_break(memory_logger):
1307
+ """Test breaking early from agent.run_stream() context manager.
1308
+
1309
+ Verifies that breaking early from the stream doesn't cause context token errors
1310
+ and that spans are still properly logged.
1311
+ """
1312
+ assert not memory_logger.pop()
1313
+
1314
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
1315
+
1316
+ start = time.time()
1317
+ text_count = 0
1318
+
1319
+ # Break early from stream - should not raise context token error
1320
+ async with agent.run_stream("Count from 1 to 10") as result:
1321
+ async for text in result.stream_text(delta=True):
1322
+ text_count += 1
1323
+ if text_count >= 2: # Lower threshold - streaming may not produce many chunks
1324
+ break # Early break
1325
+
1326
+ end = time.time()
1327
+
1328
+ assert text_count >= 1 # At least one chunk received
1329
+
1330
+ # Check spans - may have incomplete spans due to early break
1331
+ spans = memory_logger.pop()
1332
+ assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"
1333
+
1334
+ # Find agent_run_stream span (if created)
1335
+ agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
1336
+
1337
+ # Verify at least agent_run_stream span exists and has basic structure
1338
+ if agent_span:
1339
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1340
+ # Metrics may be incomplete due to early break
1341
+ assert "start" in agent_span["metrics"]
1342
+
1343
+
1344
+ @pytest.mark.vcr
1345
+ @pytest.mark.asyncio
1346
+ async def test_agent_with_binary_content(memory_logger):
1347
+ """Test that agents with binary content (images) work correctly.
1348
+
1349
+ Note: Full binary content serialization with attachment references is a complex feature.
1350
+ This test verifies basic functionality - that binary content doesn't break tracing.
1351
+ """
1352
+ from pydantic_ai.messages import BinaryContent
1353
+
1354
+ assert not memory_logger.pop()
1355
+
1356
+ # Use a small test image
1357
+ image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82'
1358
+
1359
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
1360
+
1361
+ start = time.time()
1362
+ result = await agent.run(
1363
+ [
1364
+ BinaryContent(data=image_data, media_type="image/png"),
1365
+ "What color is this image?",
1366
+ ]
1367
+ )
1368
+ end = time.time()
1369
+
1370
+ assert result.output
1371
+ assert isinstance(result.output, str)
1372
+
1373
+ # Check spans - verify basic tracing works
1374
+ spans = memory_logger.pop()
1375
+ assert len(spans) >= 1, f"Expected at least 1 span, got {len(spans)}"
1376
+
1377
+ # Find agent_run span
1378
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"] and "chat" not in s["span_attributes"]["name"]), None)
1379
+ assert agent_span is not None, "agent_run span not found"
1380
+
1381
+ # Verify basic span structure
1382
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1383
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
1384
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
1385
+
1386
+ # TODO: Future enhancement - add full binary content serialization with attachment references
1387
+
1388
+ @pytest.mark.vcr
1389
+ @pytest.mark.asyncio
1390
+ async def test_agent_with_tool_execution(memory_logger):
1391
+ """Test that tool execution creates proper span hierarchy.
1392
+
1393
+ Verifies that:
1394
+ 1. Agent creates proper spans for tool calls
1395
+ 2. Tool execution is captured in spans (ideally with "running tools" parent)
1396
+ 3. Individual tool calls create child spans
1397
+ """
1398
+ assert not memory_logger.pop()
1399
+
1400
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=200))
1401
+
1402
+ @agent.tool_plain
1403
+ def calculate(operation: str, a: float, b: float) -> str:
1404
+ """Perform a mathematical calculation.
1405
+
1406
+ Args:
1407
+ operation: The mathematical operation (add, subtract, multiply, divide)
1408
+ a: First number
1409
+ b: Second number
1410
+ """
1411
+ ops = {
1412
+ "add": a + b,
1413
+ "subtract": a - b,
1414
+ "multiply": a * b,
1415
+ "divide": a / b if b != 0 else "Error: Division by zero",
1416
+ }
1417
+ return str(ops.get(operation, "Invalid operation"))
1418
+
1419
+ start = time.time()
1420
+ result = await agent.run("What is 127 multiplied by 49?")
1421
+ end = time.time()
1422
+
1423
+ assert result.output
1424
+ assert "6" in str(result.output) and "223" in str(result.output) # Result contains 6223 (possibly formatted)
1425
+
1426
+ # Check spans
1427
+ spans = memory_logger.pop()
1428
+
1429
+ # We should have at least agent_run and chat spans
1430
+ # TODO: Add "running tools" parent span and "running tool: calculate" child span
1431
+ assert len(spans) >= 2, f"Expected at least 2 spans, got {len(spans)}"
1432
+
1433
+ # Find agent_run span
1434
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"]), None)
1435
+ assert agent_span is not None, "agent_run span not found"
1436
+
1437
+ # Verify that toolsets are captured in input with correct tool names
1438
+ assert "toolsets" in agent_span["input"], "toolsets should be in input (not metadata)"
1439
+ toolsets = agent_span["input"]["toolsets"]
1440
+ assert len(toolsets) > 0, "At least one toolset should be present"
1441
+
1442
+ # Find the agent toolset
1443
+ agent_toolset = None
1444
+ for ts in toolsets:
1445
+ if ts.get("id") == "<agent>":
1446
+ agent_toolset = ts
1447
+ break
1448
+
1449
+ assert agent_toolset is not None, "Agent toolset not found"
1450
+ assert "tools" in agent_toolset, "tools should be in agent toolset"
1451
+
1452
+ # Verify calculate tool is present (tools are now dicts with full schemas in input)
1453
+ tools = agent_toolset["tools"]
1454
+ assert isinstance(tools, list), "tools should be a list"
1455
+ tool_names = [t["name"] for t in tools if isinstance(t, dict)]
1456
+ assert "calculate" in tool_names, f"calculate tool should be in tools list, got: {tool_names}"
1457
+
1458
+ # Verify toolsets are NOT in metadata (parameters accepted by agent.run() belong in input only)
1459
+ assert "toolsets" not in agent_span["metadata"], "toolsets should NOT be in metadata"
1460
+
1461
+
1462
+ @pytest.mark.vcr
+ def test_tool_execution_creates_spans(memory_logger):
1463
+ """Test that executing tools with agents works and creates traced spans.
1464
+
1465
+ Note: Tool-level span creation is not yet implemented in the wrapper.
1466
+ This test verifies that agents with tools work correctly and produce agent/chat spans.
1467
+
1468
+ Future enhancement: Add automatic span creation for tool executions as children of
1469
+ the chat span that requested them.
1470
+ """
1471
+ assert not memory_logger.pop()
1472
+
1473
+ start = time.time()
1474
+
1475
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=500))
1476
+
1477
+ @agent.tool_plain
1478
+ def calculate(operation: str, a: float, b: float) -> float:
1479
+ """Perform a mathematical calculation."""
1480
+ if operation == "multiply":
1481
+ return a * b
1482
+ elif operation == "add":
1483
+ return a + b
1484
+ else:
1485
+ return 0.0
1486
+
1487
+ # Run the agent with a query that will use the tool
1488
+ result = agent.run_sync("What is 127 multiplied by 49?")
1489
+ end = time.time()
1490
+
1491
+ # Verify the tool was actually called and result is correct
1492
+ assert result.output
1493
+ assert "6223" in str(result.output) or "6,223" in str(result.output), f"Expected calculation result in output: {result.output}"
1494
+
1495
+ # Get logged spans
1496
+ spans = memory_logger.pop()
1497
+
1498
+ # Find spans by type
1499
+ agent_span = next((s for s in spans if "agent_run" in s["span_attributes"]["name"]), None)
1500
+ chat_spans = [s for s in spans if "chat" in s["span_attributes"]["name"]]
1501
+
1502
+ # Assertions - verify basic tracing works with tools
1503
+ assert agent_span is not None, "agent_run span should exist"
1504
+ assert len(chat_spans) >= 1, f"Expected at least 1 chat span, got {len(chat_spans)}"
1505
+
1506
+ # Verify agent span has tool information in input
1507
+ assert "toolsets" in agent_span["input"], "Tool information should be captured in agent input"
1508
+ toolsets = agent_span["input"]["toolsets"]
1509
+ agent_toolset = next((ts for ts in toolsets if ts.get("id") == "<agent>"), None)
1510
+ assert agent_toolset is not None, "Agent toolset should be in input"
1511
+
1512
+ # Verify calculate tool is in the toolset
1513
+ tools = agent_toolset.get("tools", [])
1514
+ tool_names = [t["name"] for t in tools if isinstance(t, dict)]
1515
+ assert "calculate" in tool_names, f"calculate tool should be in toolset, got: {tool_names}"
1516
+
1517
+ # TODO: Future enhancement - verify tool execution spans are created
1518
+ # tool_spans = [s for s in spans if "calculate" in s["span_attributes"].get("name", "")]
1519
+ # assert len(tool_spans) > 0, "Tool execution should create spans"
1520
+
1521
+
1522
+ def test_agent_tool_metadata_extraction(memory_logger):
1523
+ """Test that agent tools are properly extracted with full schemas in INPUT (not metadata).
1524
+
1525
+ Principle: if agent.run() accepts a parameter, it is logged in input only, never in metadata (see the sketch after this test).
1526
+ """
1527
+ from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata
1528
+
1529
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
1530
+
1531
+ # Add multiple tools with different signatures
1532
+ @agent.tool_plain
1533
+ def calculate(operation: str, a: float, b: float) -> str:
1534
+ """Perform a mathematical calculation."""
1535
+ return str(a + b)
1536
+
1537
+ @agent.tool_plain
1538
+ def get_weather(location: str) -> str:
1539
+ """Get weather for a location."""
1540
+ return f"Weather in {location}"
1541
+
1542
+ @agent.tool_plain
1543
+ def search_database(query: str, limit: int = 10) -> str:
1544
+ """Search the database."""
1545
+ return "Results"
1546
+
1547
+ # Extract metadata using the actual function signature
1548
+ args = ("Test prompt",)
1549
+ kwargs = {}
1550
+ input_data, metadata = _build_agent_input_and_metadata(args, kwargs, agent)
1551
+
1552
+ # Verify toolsets are in INPUT (since agent.run() accepts toolsets parameter)
1553
+ assert "toolsets" in input_data, "toolsets should be in input (can be passed to agent.run())"
1554
+ toolsets = input_data["toolsets"]
1555
+ assert len(toolsets) > 0, "At least one toolset should be present"
1556
+
1557
+ # Verify toolsets are NOT in metadata (following the principle)
1558
+ assert "toolsets" not in metadata, "toolsets should NOT be in metadata (it's a run() parameter)"
1559
+
1560
+ # Find the agent toolset
1561
+ agent_toolset = None
1562
+ for ts in toolsets:
1563
+ if ts.get("id") == "<agent>":
1564
+ agent_toolset = ts
1565
+ break
1566
+
1567
+ assert agent_toolset is not None, "Agent toolset not found in input"
1568
+ assert agent_toolset.get("label") == "the agent", "Agent toolset should have correct label"
1569
+ assert "tools" in agent_toolset, "tools should be in agent toolset"
1570
+
1571
+ # Verify all tools are present with FULL SCHEMAS
1572
+ tools = agent_toolset["tools"]
1573
+ assert isinstance(tools, list), "tools should be a list"
1574
+ assert len(tools) == 3, f"Should have exactly 3 tools, got {len(tools)}"
1575
+
1576
+ # Check each tool has full schema information
1577
+ tool_names = [t["name"] for t in tools]
1578
+ assert "calculate" in tool_names, f"calculate tool should be present, got: {tool_names}"
1579
+ assert "get_weather" in tool_names, f"get_weather tool should be present, got: {tool_names}"
1580
+ assert "search_database" in tool_names, f"search_database tool should be present, got: {tool_names}"
1581
+
1582
+ # Verify calculate tool has full schema
1583
+ calculate_tool = next(t for t in tools if t["name"] == "calculate")
1584
+ assert "description" in calculate_tool, "Tool should have description"
1585
+ assert "Perform a mathematical calculation" in calculate_tool["description"]
1586
+ assert "parameters" in calculate_tool, "Tool should have parameters schema"
1587
+ params = calculate_tool["parameters"]
1588
+ assert "properties" in params, "Parameters should have properties"
1589
+ assert "operation" in params["properties"], "Should have 'operation' parameter"
1590
+ assert "a" in params["properties"], "Should have 'a' parameter"
1591
+ assert "b" in params["properties"], "Should have 'b' parameter"
1592
+ assert params["properties"]["operation"]["type"] == "string"
1593
+ assert params["properties"]["a"]["type"] == "number"
1594
+ assert params["properties"]["b"]["type"] == "number"
1595
+
1596
+ # Verify search_database has optional parameter
1597
+ search_tool = next(t for t in tools if t["name"] == "search_database")
1598
+ assert "parameters" in search_tool
1599
+ search_params = search_tool["parameters"]
1600
+ assert "query" in search_params["properties"]
1601
+ assert "limit" in search_params["properties"]
1602
+ # 'query' should be required, 'limit' should be optional (has default)
1603
+ assert "query" in search_params.get("required", [])
1604
+
1605
+
1606
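A hedged usage sketch of the split this test asserts; the call shape matches the test above, while the assertions are inferred from its expectations rather than from the wrapper source:

    # Anything agent.run() itself accepts (the prompt, toolsets, ...) is logged
    # in input and is never duplicated into metadata.
    agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=100))
    input_data, metadata = _build_agent_input_and_metadata(("What is 2+2?",), {}, agent)
    assert "toolsets" in input_data
    assert "toolsets" not in metadata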
+ def test_agent_without_tools_metadata():
1607
+ """Test metadata extraction for agent without tools."""
1608
+ from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata
1609
+
1610
+ # Agent with no tools
1611
+ agent = Agent(MODEL, model_settings=ModelSettings(max_tokens=50))
1612
+
1613
+ args = ("Test prompt",)
1614
+ kwargs = {}
1615
+ input_data, metadata = _build_agent_input_and_metadata(args, kwargs, agent)
1616
+
1617
+ # Should have toolsets in input (even if empty)
1618
+ # Note: Pydantic AI agents always have some toolsets (e.g., for output parsing)
1619
+ # so we just verify the structure exists
1620
+ assert isinstance(input_data.get("toolsets"), (list, type(None))), "toolsets should be list or None in input"
1621
+
1622
+
1623
+ def test_agent_tool_with_custom_name():
1624
+ """Test that tools with custom names are properly extracted with schemas in input."""
1625
+ from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata
1626
+
1627
+ agent = Agent(MODEL)
1628
+
1629
+ # Add tool with custom name
1630
+ @agent.tool_plain(name="custom_calculator")
1631
+ def calc(a: int, b: int) -> int:
1632
+ """Add two numbers."""
1633
+ return a + b
1634
+
1635
+ args = ("Test",)
1636
+ kwargs = {}
1637
+ input_data, metadata = _build_agent_input_and_metadata(args, kwargs, agent)
1638
+
1639
+ # Verify custom name is used in input (not metadata)
1640
+ assert "toolsets" in input_data
1641
+ assert "toolsets" not in metadata, "toolsets should not be in metadata"
1642
+
1643
+ agent_toolset = next((ts for ts in input_data["toolsets"] if ts.get("id") == "<agent>"), None)
1644
+ assert agent_toolset is not None
1645
+ tools = agent_toolset.get("tools", [])
1646
+
1647
+ # The tool should be a dict with schema info
1648
+ assert len(tools) == 1, f"Should have 1 tool, got {len(tools)}"
1649
+ tool = tools[0]
1650
+ assert isinstance(tool, dict), "Tool should be a dict with schema"
1651
+ assert tool["name"] == "custom_calculator", f"Should use custom name, got: {tool.get('name')}"
1652
+ assert "description" in tool, "Tool should have description"
1653
+ assert "parameters" in tool, "Tool should have parameters schema"
1654
+ assert "a" in tool["parameters"]["properties"]
1655
+ assert "b" in tool["parameters"]["properties"]
1656
+
1657
+
1658
+ def test_explicit_toolsets_kwarg_in_input():
1659
+ """Test that explicitly passed toolsets kwarg goes to input (not just metadata)."""
1660
+ from braintrust.wrappers.pydantic_ai import _build_agent_input_and_metadata
1661
+
1662
+ agent = Agent(MODEL)
1663
+
1664
+ # Add a tool to the agent
1665
+ @agent.tool_plain
1666
+ def helper_tool() -> str:
1667
+ """A helper tool."""
1668
+ return "help"
1669
+
1670
+ # Simulate passing toolsets as an explicit kwarg (in practice this would be a real toolset object)
1671
+ # For testing, pass a simple marker string so it is easy to spot in the logged input
1672
+ args = ("Test",)
1673
+ kwargs = {"toolsets": "custom_toolset_marker"} # Simplified for testing
1674
+ input_data, metadata = _build_agent_input_and_metadata(args, kwargs, agent)
1675
+
1676
+ # Toolsets passed as kwargs should be in input
1677
+ assert "toolsets" in input_data, "explicitly passed toolsets should be in input"
1678
+
1679
+
1680
+ @pytest.mark.vcr
1681
+ def test_reasoning_tokens_extraction(memory_logger):
1682
+ """Test that reasoning tokens are extracted from model responses.
1683
+
1684
+ For reasoning models like o1/o3, usage.details.reasoning_tokens should be
1685
+ captured in the metrics field (a mapping sketch follows this test).
1686
+ """
1687
+ assert not memory_logger.pop()
1688
+
1689
+ # Mock a response that has reasoning tokens
1690
+ from unittest.mock import MagicMock
1691
+
1692
+ # Create a mock response with reasoning tokens
1693
+ mock_response = MagicMock()
1694
+ mock_response.parts = [
1695
+ MagicMock(
1696
+ part_kind="thinking",
1697
+ content="Let me think about this...",
1698
+ ),
1699
+ MagicMock(
1700
+ part_kind="text",
1701
+ content="The answer is 42",
1702
+ ),
1703
+ ]
1704
+ mock_response.usage = MagicMock()
1705
+ mock_response.usage.input_tokens = 10
1706
+ mock_response.usage.output_tokens = 20
1707
+ mock_response.usage.total_tokens = 30
1708
+ mock_response.usage.cache_read_tokens = 0
1709
+ mock_response.usage.cache_write_tokens = 0
1710
+ mock_response.usage.details = MagicMock()
1711
+ mock_response.usage.details.reasoning_tokens = 128
1712
+
1713
+ # Test the metric extraction function directly
1714
+ from braintrust.wrappers.pydantic_ai import _extract_response_metrics
1715
+
1716
+ start_time = time.time()
1717
+ end_time = start_time + 5.0
1718
+
1719
+ metrics = _extract_response_metrics(mock_response, start_time, end_time)
1720
+
1721
+ # Verify all metrics are present
1722
+ assert metrics is not None, "Should extract metrics"
1723
+ # pylint: disable=unsupported-membership-test,unsubscriptable-object
1724
+ assert "prompt_tokens" in metrics, "Should have prompt_tokens"
1725
+ assert metrics["prompt_tokens"] == 10.0
1726
+ assert "completion_tokens" in metrics, "Should have completion_tokens"
1727
+ assert metrics["completion_tokens"] == 20.0
1728
+ assert "tokens" in metrics, "Should have total tokens"
1729
+ assert metrics["tokens"] == 30.0
1730
+ assert "completion_reasoning_tokens" in metrics, "Should have completion_reasoning_tokens"
1731
+ assert metrics["completion_reasoning_tokens"] == 128.0, f"Expected 128.0, got {metrics['completion_reasoning_tokens']}"
1732
+ assert "duration" in metrics
1733
+ assert "start" in metrics
1734
+ assert "end" in metrics
1735
+ # pylint: enable=unsupported-membership-test,unsubscriptable-object
1736
+
1737
+
1738
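A compact sketch of the usage-to-metrics mapping the assertions above expect; the metric names mirror the test, but this is not the wrapper's actual _extract_response_metrics implementation:

    def sketch_usage_to_metrics(usage, start, end):
        # Map pydantic_ai usage fields onto Braintrust-style metric names.
        metrics = {
            "prompt_tokens": float(getattr(usage, "input_tokens", 0) or 0),
            "completion_tokens": float(getattr(usage, "output_tokens", 0) or 0),
            "tokens": float(getattr(usage, "total_tokens", 0) or 0),
            "start": start,
            "end": end,
            "duration": end - start,
        }
        details = getattr(usage, "details", None)
        reasoning = getattr(details, "reasoning_tokens", None) if details is not None else None
        if reasoning:
            # Reasoning models (e.g. o1/o3) report extra "thinking" tokens here.
            metrics["completion_reasoning_tokens"] = float(reasoning)
        return metrics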
+ @pytest.mark.vcr
1739
+ @pytest.mark.asyncio
1740
+ async def test_agent_run_stream_structured_output(memory_logger):
1741
+ """Test Agent.run_stream() with structured output (Pydantic model).
1742
+
1743
+ Verifies that streaming structured output creates proper spans and
1744
+ that the result can be accessed via get_output() method.
1745
+ """
1746
+ assert not memory_logger.pop()
1747
+
1748
+ class Product(BaseModel):
1749
+ name: str
1750
+ price: float
1751
+
1752
+ agent = Agent(
1753
+ MODEL,
1754
+ output_type=Product,
1755
+ model_settings=ModelSettings(max_tokens=200)
1756
+ )
1757
+
1758
+ start = time.time()
1759
+ async with agent.run_stream("Create a product: wireless mouse for $29.99") as result:
1760
+ # For structured output, use get_output() instead of streaming text
1761
+ product = await result.get_output()
1762
+ end = time.time()
1763
+
1764
+ # Verify structured output
1765
+ assert isinstance(product, Product)
1766
+ assert product.name
1767
+ assert product.price > 0
1768
+
1769
+ # Check spans
1770
+ spans = memory_logger.pop()
1771
+ assert len(spans) >= 2, f"Expected at least 2 spans (agent_run_stream + chat), got {len(spans)}"
1772
+
1773
+ # Find agent_run_stream and chat spans
1774
+ agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
1775
+ chat_span = next((s for s in spans if "chat" in s["span_attributes"]["name"]), None)
1776
+
1777
+ assert agent_span is not None, "agent_run_stream span not found"
1778
+ assert chat_span is not None, "chat span not found"
1779
+
1780
+ # Check agent span
1781
+ assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
1782
+ assert agent_span["metadata"]["model"] == "gpt-4o-mini"
1783
+ _assert_metrics_are_valid(agent_span["metrics"], start, end)
1784
+
1785
+ # Check chat span is nested
1786
+ assert chat_span["span_parents"] == [agent_span["span_id"]], "chat span should be nested under agent_run_stream"
1787
+ assert chat_span["metadata"]["model"] == "gpt-4o-mini"
1788
+ _assert_metrics_are_valid(chat_span["metrics"], start, end)