pydantic-ai 0.0.29__tar.gz → 0.0.30__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai might be problematic.

Files changed (62)
  1. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/PKG-INFO +3 -3
  2. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/pyproject.toml +3 -3
  3. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_instrumented.py +141 -3
  4. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_model_test.py +19 -0
  5. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_examples.py +53 -25
  6. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_streaming.py +112 -1
  7. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/.gitignore +0 -0
  8. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/LICENSE +0 -0
  9. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/Makefile +0 -0
  10. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/README.md +0 -0
  11. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/__init__.py +0 -0
  12. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/assets/kiwi.png +0 -0
  13. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/assets/marcelo.mp3 +0 -0
  14. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/conftest.py +0 -0
  15. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/example_modules/README.md +0 -0
  16. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/example_modules/bank_database.py +0 -0
  17. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/example_modules/fake_database.py +0 -0
  18. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/example_modules/weather_service.py +0 -0
  19. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/graph/__init__.py +0 -0
  20. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/graph/test_graph.py +0 -0
  21. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/graph/test_history.py +0 -0
  22. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/graph/test_mermaid.py +0 -0
  23. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/graph/test_state.py +0 -0
  24. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/graph/test_utils.py +0 -0
  25. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/import_examples.py +0 -0
  26. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/json_body_serializer.py +0 -0
  27. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/__init__.py +0 -0
  28. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_anthropic/test_image_url_input.yaml +0 -0
  29. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_anthropic/test_image_url_input_invalid_mime_type.yaml +0 -0
  30. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_anthropic/test_multiple_parallel_tool_calls.yaml +0 -0
  31. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_gemini/test_image_as_binary_content_input.yaml +0 -0
  32. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_gemini/test_image_url_input.yaml +0 -0
  33. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_groq/test_image_as_binary_content_input.yaml +0 -0
  34. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_groq/test_image_url_input.yaml +0 -0
  35. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_openai/test_audio_as_binary_content_input.yaml +0 -0
  36. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_openai/test_image_as_binary_content_input.yaml +0 -0
  37. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_openai/test_openai_o1_mini_system_role[developer].yaml +0 -0
  38. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/cassettes/test_openai/test_openai_o1_mini_system_role[system].yaml +0 -0
  39. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/mock_async_stream.py +0 -0
  40. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_anthropic.py +0 -0
  41. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_cohere.py +0 -0
  42. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_fallback.py +0 -0
  43. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_gemini.py +0 -0
  44. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_groq.py +0 -0
  45. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_mistral.py +0 -0
  46. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_model.py +0 -0
  47. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_model_function.py +0 -0
  48. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_model_names.py +0 -0
  49. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_openai.py +0 -0
  50. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_vertexai.py +0 -0
  51. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_agent.py +0 -0
  52. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_deps.py +0 -0
  53. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_format_as_xml.py +0 -0
  54. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_json_body_serializer.py +0 -0
  55. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_live.py +0 -0
  56. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_logfire.py +0 -0
  57. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_parts_manager.py +0 -0
  58. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_tools.py +0 -0
  59. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_usage_limits.py +0 -0
  60. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_utils.py +0 -0
  61. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/typed_agent.py +0 -0
  62. {pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/typed_graph.py +0 -0
{pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/PKG-INFO +3 -3

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai
-Version: 0.0.29
+Version: 0.0.30
 Summary: Agent Framework / shim to use Pydantic with LLMs
 Project-URL: Homepage, https://ai.pydantic.dev
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -28,9 +28,9 @@ Classifier: Topic :: Internet
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.29
+Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.30
 Provides-Extra: examples
-Requires-Dist: pydantic-ai-examples==0.0.29; extra == 'examples'
+Requires-Dist: pydantic-ai-examples==0.0.30; extra == 'examples'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Description-Content-Type: text/markdown
{pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/pyproject.toml +3 -3

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai"
-version = "0.0.29"
+version = "0.0.30"
 description = "Agent Framework / shim to use Pydantic with LLMs"
 authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
 license = "MIT"
@@ -32,7 +32,7 @@ classifiers = [
 requires-python = ">=3.9"
 
 dependencies = [
-    "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.29",
+    "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.30",
 ]
 
 [project.urls]
@@ -42,7 +42,7 @@ Documentation = "https://ai.pydantic.dev"
 Changelog = "https://github.com/pydantic/pydantic-ai/releases"
 
 [project.optional-dependencies]
-examples = ["pydantic-ai-examples==0.0.29"]
+examples = ["pydantic-ai-examples==0.0.30"]
 logfire = ["logfire>=2.3"]
 
 [tool.uv.sources]
{pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_instrumented.py +141 -3

@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager
 from datetime import datetime
 
 import pytest
+from dirty_equals import IsJson
 from inline_snapshot import snapshot
 from logfire_api import DEFAULT_LOGFIRE_INSTANCE
 
@@ -105,7 +106,7 @@ class MyResponseStream(StreamedResponse):
 @pytest.mark.anyio
 @requires_logfire_events
 async def test_instrumented_model(capfire: CaptureLogfire):
-    model = InstrumentedModel.from_logfire(MyModel())
+    model = InstrumentedModel.from_logfire(MyModel(), event_mode='logs')
     assert model.system == 'my_system'
     assert model.model_name == 'my_model'
 
@@ -323,7 +324,7 @@ async def test_instrumented_model_not_recording():
 @pytest.mark.anyio
 @requires_logfire_events
 async def test_instrumented_model_stream(capfire: CaptureLogfire):
-    model = InstrumentedModel.from_logfire(MyModel())
+    model = InstrumentedModel.from_logfire(MyModel(), event_mode='logs')
 
     messages: list[ModelMessage] = [
         ModelRequest(
@@ -405,7 +406,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire):
 @pytest.mark.anyio
 @requires_logfire_events
 async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
-    model = InstrumentedModel.from_logfire(MyModel())
+    model = InstrumentedModel.from_logfire(MyModel(), event_mode='logs')
 
     messages: list[ModelMessage] = [
         ModelRequest(
@@ -494,3 +495,140 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
             },
         ]
     )
+
+
+@pytest.mark.anyio
+async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
+    model = InstrumentedModel(MyModel(), event_mode='attributes')
+    assert model.system == 'my_system'
+    assert model.model_name == 'my_model'
+
+    messages = [
+        ModelRequest(
+            parts=[
+                SystemPromptPart('system_prompt'),
+                UserPromptPart('user_prompt'),
+                ToolReturnPart('tool3', 'tool_return_content', 'tool_call_3'),
+                RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'),
+                RetryPromptPart('retry_prompt2'),
+                {},  # test unexpected parts  # type: ignore
+            ]
+        ),
+        ModelResponse(
+            parts=[
+                TextPart('text3'),
+            ]
+        ),
+    ]
+    await model.request(
+        messages,
+        model_settings=ModelSettings(temperature=1),
+        model_request_parameters=ModelRequestParameters(
+            function_tools=[],
+            allow_text_result=True,
+            result_tools=[],
+        ),
+    )
+
+    assert capfire.exporter.exported_spans_as_dict() == snapshot(
+        [
+            {
+                'name': 'chat my_model',
+                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'parent': None,
+                'start_time': 1000000000,
+                'end_time': 2000000000,
+                'attributes': {
+                    'gen_ai.operation.name': 'chat',
+                    'gen_ai.system': 'my_system',
+                    'gen_ai.request.model': 'my_model',
+                    'gen_ai.request.temperature': 1,
+                    'logfire.msg': 'chat my_model',
+                    'logfire.span_type': 'span',
+                    'gen_ai.response.model': 'my_model_123',
+                    'gen_ai.usage.input_tokens': 100,
+                    'gen_ai.usage.output_tokens': 200,
+                    'events': IsJson(
+                        snapshot(
+                            [
+                                {
+                                    'event.name': 'gen_ai.system.message',
+                                    'content': 'system_prompt',
+                                    'role': 'system',
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.user.message',
+                                    'content': 'user_prompt',
+                                    'role': 'user',
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.tool.message',
+                                    'content': 'tool_return_content',
+                                    'role': 'tool',
+                                    'id': 'tool_call_3',
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.tool.message',
+                                    'content': """\
+retry_prompt1
+
+Fix the errors and try again.\
+""",
+                                    'role': 'tool',
+                                    'id': 'tool_call_4',
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.user.message',
+                                    'content': """\
+retry_prompt2
+
+Fix the errors and try again.\
+""",
+                                    'role': 'user',
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.assistant.message',
+                                    'role': 'assistant',
+                                    'content': 'text3',
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.choice',
+                                    'index': 0,
+                                    'message': {
+                                        'role': 'assistant',
+                                        'content': 'text1',
+                                        'tool_calls': [
+                                            {
+                                                'id': 'tool_call_1',
+                                                'type': 'function',
+                                                'function': {'name': 'tool1', 'arguments': 'args1'},
+                                            },
+                                            {
+                                                'id': 'tool_call_2',
+                                                'type': 'function',
+                                                'function': {'name': 'tool2', 'arguments': {'args2': 3}},
+                                            },
+                                        ],
+                                    },
+                                    'gen_ai.system': 'my_system',
+                                },
+                                {
+                                    'event.name': 'gen_ai.choice',
+                                    'index': 0,
+                                    'message': {'role': 'assistant', 'content': 'text2'},
+                                    'gen_ai.system': 'my_system',
+                                },
+                            ]
+                        )
+                    ),
+                    'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}}}',
+                },
+            },
+        ]
+    )
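
The tests above exercise a new event_mode option on InstrumentedModel. As a rough orientation, here is a minimal sketch of how that wrapper might be used; the import paths, the OpenAI model and the Agent wiring are assumptions, while the constructor call and the 'logs'/'attributes' values come directly from the diff.

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentedModel  # assumed import path
    from pydantic_ai.models.openai import OpenAIModel

    # 'attributes' packs the chat events into a JSON `events` attribute on the request span
    # (as asserted in test_instrumented_model_attributes_mode), while 'logs' emits one
    # event per message, as used via InstrumentedModel.from_logfire in the other tests.
    model = InstrumentedModel(OpenAIModel('gpt-4o'), event_mode='attributes')

    agent = Agent(model)
    result = agent.run_sync('What is the capital of France?')
    print(result.data)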
{pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/models/test_model_test.py +19 -0

@@ -13,6 +13,9 @@ from pydantic import BaseModel, Field
 from pydantic_ai import Agent, ModelRetry, RunContext
 from pydantic_ai.exceptions import UnexpectedModelBehavior
 from pydantic_ai.messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelRequest,
     ModelResponse,
     RetryPromptPart,
@@ -22,6 +25,7 @@ from pydantic_ai.messages import (
     UserPromptPart,
 )
 from pydantic_ai.models.test import TestModel, _chars, _JsonSchemaTestData  # pyright: ignore[reportPrivateUsage]
+from pydantic_ai.usage import Usage
 
 from ..conftest import IsNow
 
@@ -271,3 +275,18 @@ def test_max_items():
     }
     data = _JsonSchemaTestData(json_schema).generate()
     assert data == snapshot([])
+
+
+@pytest.mark.parametrize(
+    'content',
+    [
+        AudioUrl(url='https://example.com'),
+        ImageUrl(url='https://example.com'),
+        BinaryContent(data=b'', media_type='image/png'),
+    ],
+)
+def test_different_content_input(content: AudioUrl | ImageUrl | BinaryContent):
+    agent = Agent()
+    result = agent.run_sync('x', model=TestModel(custom_result_text='custom'))
+    assert result.data == snapshot('custom')
+    assert result.usage() == snapshot(Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52))
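
The new test imports AudioUrl, ImageUrl and BinaryContent and runs an agent against TestModel. Below is a hedged sketch of how such content is normally attached to a prompt; the URL, the placeholder bytes and the list-of-parts prompt form are assumptions based on the library's multimodal input support, not on this test, which only constructs the objects.

    from pydantic_ai import Agent
    from pydantic_ai.messages import BinaryContent, ImageUrl
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(custom_result_text='custom'))

    # A user prompt may mix plain text with media parts (assumed prompt form).
    result = agent.run_sync(['What is in this image?', ImageUrl(url='https://example.com/kiwi.png')])
    print(result.data)  # TestModel ignores the attachment and returns its canned 'custom' text

    png_bytes = b'\x89PNG...'  # placeholder bytes, not a real image
    result = agent.run_sync(['Describe this file', BinaryContent(data=png_bytes, media_type='image/png')])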
{pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_examples.py +53 -25

@@ -182,6 +182,9 @@ text_responses: dict[str, str | ToolCallPart] = {
     'What is the weather like in West London and in Wiltshire?': (
         'The weather in West London is raining, while in Wiltshire it is sunny.'
     ),
+    'What will the weather be like in Paris on Tuesday?': ToolCallPart(
+        tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'}, tool_call_id='0001'
+    ),
     'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.',
     'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
     'What is the capital of France?': 'Paris',
@@ -270,6 +273,13 @@ text_responses: dict[str, str | ToolCallPart] = {
     ),
 }
 
+tool_responses: dict[tuple[str, str], str] = {
+    (
+        'weather_forecast',
+        'The forecast in Paris on 2030-01-01 is 24°C and sunny.',
+    ): 'It will be warm and sunny in Paris on Tuesday.',
+}
+
 
 async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:  # pragma: no cover  # noqa: C901
     m = messages[-1].parts[-1]
@@ -348,35 +358,53 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
     raise RuntimeError(f'Unexpected message: {m}')
 
 
-async def stream_model_logic(
+async def stream_model_logic(  # noqa C901
     messages: list[ModelMessage], info: AgentInfo
 ) -> AsyncIterator[str | DeltaToolCalls]:  # pragma: no cover
-    m = messages[-1].parts[-1]
-    if isinstance(m, UserPromptPart):
-        assert isinstance(m.content, str)
-        if response := text_responses.get(m.content):
-            if isinstance(response, str):
-                words = response.split(' ')
-                chunk: list[str] = []
-                for work in words:
-                    chunk.append(work)
-                    if len(chunk) == 3:
-                        yield ' '.join(chunk) + ' '
-                        chunk.clear()
-                if chunk:
-                    yield ' '.join(chunk)
-                return
-            else:
-                json_text = response.args_as_json_str()
-
-                yield {1: DeltaToolCall(name=response.tool_name)}
-                for chunk_index in range(0, len(json_text), 15):
-                    text_chunk = json_text[chunk_index : chunk_index + 15]
-                    yield {1: DeltaToolCall(json_args=text_chunk)}
-                return
+    async def stream_text_response(r: str) -> AsyncIterator[str]:
+        if isinstance(r, str):
+            words = r.split(' ')
+            chunk: list[str] = []
+            for word in words:
+                chunk.append(word)
+                if len(chunk) == 3:
+                    yield ' '.join(chunk) + ' '
+                    chunk.clear()
+            if chunk:
+                yield ' '.join(chunk)
+
+    async def stream_tool_call_response(r: ToolCallPart) -> AsyncIterator[DeltaToolCalls]:
+        json_text = r.args_as_json_str()
+
+        yield {1: DeltaToolCall(name=r.tool_name, tool_call_id=r.tool_call_id)}
+        for chunk_index in range(0, len(json_text), 15):
+            text_chunk = json_text[chunk_index : chunk_index + 15]
+            yield {1: DeltaToolCall(json_args=text_chunk)}
+
+    async def stream_part_response(r: str | ToolCallPart) -> AsyncIterator[str | DeltaToolCalls]:
+        if isinstance(r, str):
+            async for chunk in stream_text_response(r):
+                yield chunk
+        else:
+            async for chunk in stream_tool_call_response(r):
+                yield chunk
+
+    last_part = messages[-1].parts[-1]
+    if isinstance(last_part, UserPromptPart):
+        assert isinstance(last_part.content, str)
+        if response := text_responses.get(last_part.content):
+            async for chunk in stream_part_response(response):
+                yield chunk
+            return
+    elif isinstance(last_part, ToolReturnPart):
+        assert isinstance(last_part.content, str)
+        if response := tool_responses.get((last_part.tool_name, last_part.content)):
+            async for chunk in stream_part_response(response):
+                yield chunk
+            return
 
     sys.stdout.write(str(debug.format(messages, info)))
-    raise RuntimeError(f'Unexpected message: {m}')
+    raise RuntimeError(f'Unexpected message: {last_part}')
 
 
 def mock_infer_model(model: Model | KnownModelName) -> Model:
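
stream_model_logic above is the FunctionModel stream handler used by the documentation tests. For orientation, here is a self-contained sketch of the same pattern; the canned answer, prompt and the exact wiring are assumptions, only FunctionModel, AgentInfo, DeltaToolCalls and the three-words-at-a-time chunking come from the diff.

    from __future__ import annotations

    import asyncio
    from collections.abc import AsyncIterator

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models.function import AgentInfo, DeltaToolCalls, FunctionModel


    async def stream_logic(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]:
        # Yield a made-up canned answer three words at a time, mirroring stream_text_response above.
        words = 'It will be warm and sunny in Paris on Tuesday.'.split(' ')
        for i in range(0, len(words), 3):
            yield ' '.join(words[i : i + 3]) + ' '


    agent = Agent(FunctionModel(stream_function=stream_logic))


    async def main() -> None:
        async with agent.run_stream('What will the weather be like in Paris on Tuesday?') as result:
            async for text in result.stream_text():
                print(text)


    asyncio.run(main())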
{pydantic_ai-0.0.29 → pydantic_ai-0.0.30}/tests/test_streaming.py +112 -1

@@ -2,14 +2,18 @@ from __future__ import annotations as _annotations
 
 import datetime
 import json
+import re
 from collections.abc import AsyncIterator
+from copy import deepcopy
 from datetime import timezone
+from typing import Union
 
 import pytest
 from inline_snapshot import snapshot
 from pydantic import BaseModel
 
 from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages
+from pydantic_ai.agent import AgentRun
 from pydantic_ai.messages import (
     ModelMessage,
     ModelRequest,
@@ -22,7 +26,8 @@ from pydantic_ai.messages import (
 )
 from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
 from pydantic_ai.models.test import TestModel
-from pydantic_ai.result import Usage
+from pydantic_ai.result import AgentStream, FinalResult, Usage
+from pydantic_graph import End
 
 from .conftest import IsNow
 
@@ -739,3 +744,109 @@ async def test_custom_result_type_default_structured() -> None:
     async with agent.run_stream('test', result_type=str) as result:
         response = await result.get_data()
         assert response == snapshot('success (no tool calls)')
+
+
+async def test_iter_stream_output():
+    m = TestModel(custom_result_text='The cat sat on the mat.')
+
+    agent = Agent(m)
+
+    @agent.result_validator
+    def result_validator_simple(data: str) -> str:
+        # Make a substitution in the validated results
+        return re.sub('cat sat', 'bat sat', data)
+
+    run: AgentRun
+    stream: AgentStream
+    messages: list[str] = []
+
+    stream_usage: Usage | None = None
+    with agent.iter('Hello') as run:
+        async for node in run:
+            if agent.is_model_request_node(node):
+                async with node.stream(run.ctx) as stream:
+                    async for chunk in stream.stream_output(debounce_by=None):
+                        messages.append(chunk)
+                    stream_usage = deepcopy(stream.usage())
+    assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None))
+    assert (
+        run.usage()
+        == stream_usage
+        == Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None)
+    )
+
+    assert messages == [
+        '',
+        'The ',
+        'The cat ',
+        'The bat sat ',
+        'The bat sat on ',
+        'The bat sat on the ',
+        'The bat sat on the mat.',
+        'The bat sat on the mat.',
+    ]
+
+
+async def test_iter_stream_responses():
+    m = TestModel(custom_result_text='The cat sat on the mat.')
+
+    agent = Agent(m)
+
+    @agent.result_validator
+    def result_validator_simple(data: str) -> str:
+        # Make a substitution in the validated results
+        return re.sub('cat sat', 'bat sat', data)
+
+    run: AgentRun
+    stream: AgentStream
+    messages: list[ModelResponse] = []
+    with agent.iter('Hello') as run:
+        async for node in run:
+            if agent.is_model_request_node(node):
+                async with node.stream(run.ctx) as stream:
+                    async for chunk in stream.stream_responses(debounce_by=None):
+                        messages.append(chunk)
+
+    assert messages == [
+        ModelResponse(
+            parts=[TextPart(content=text, part_kind='text')],
+            model_name='test',
+            timestamp=IsNow(tz=timezone.utc),
+            kind='response',
+        )
+        for text in [
+            '',
+            '',
+            'The ',
+            'The cat ',
+            'The cat sat ',
+            'The cat sat on ',
+            'The cat sat on the ',
+            'The cat sat on the mat.',
+        ]
+    ]
+
+    # Note: as you can see above, the result validator is not applied to the streamed responses, just the final result:
+    assert run.result is not None
+    assert run.result.data == 'The bat sat on the mat.'
+
+
+async def test_stream_iter_structured_validator() -> None:
+    class NotResultType(BaseModel):
+        not_value: str
+
+    agent = Agent[None, Union[ResultType, NotResultType]]('test', result_type=Union[ResultType, NotResultType])  # pyright: ignore[reportArgumentType]
+
+    @agent.result_validator
+    def result_validator(data: ResultType | NotResultType) -> ResultType | NotResultType:
+        assert isinstance(data, ResultType)
+        return ResultType(value=data.value + ' (validated)')
+
+    outputs: list[ResultType] = []
+    with agent.iter('test') as run:
+        async for node in run:
+            if agent.is_model_request_node(node):
+                async with node.stream(run.ctx) as stream:
+                    async for output in stream.stream_output(debounce_by=None):
+                        outputs.append(output)
+    assert outputs == [ResultType(value='a (validated)'), ResultType(value='a (validated)')]
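
The three tests above all drive a run through agent.iter() and stream from the model-request node. Here is a condensed sketch of that pattern, using only the calls that appear in the tests; TestModel, the prompt and the printing are placeholders.

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(custom_result_text='The cat sat on the mat.'))


    async def main() -> None:
        chunks: list[str] = []
        with agent.iter('Hello') as run:  # iter() is used as a sync context manager in the tests above
            async for node in run:  # walk the run graph node by node
                if agent.is_model_request_node(node):
                    async with node.stream(run.ctx) as stream:
                        async for chunk in stream.stream_output(debounce_by=None):
                            chunks.append(chunk)
        print(chunks[-1])
        print(run.usage())


    asyncio.run(main())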