pydantic-ai-slim 0.0.55__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (53):
  1. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/PKG-INFO +5 -5
  2. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/__init__.py +10 -3
  3. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_agent_graph.py +67 -55
  4. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_cli.py +1 -2
  5. pydantic_ai_slim-0.0.55/pydantic_ai/_result.py → pydantic_ai_slim-0.1.0/pydantic_ai/_output.py +69 -47
  6. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_utils.py +20 -0
  7. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/agent.py +501 -161
  8. pydantic_ai_slim-0.1.0/pydantic_ai/format_as_xml.py +9 -0
  9. pydantic_ai_slim-0.0.55/pydantic_ai/format_as_xml.py → pydantic_ai_slim-0.1.0/pydantic_ai/format_prompt.py +1 -1
  10. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/messages.py +104 -21
  11. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/__init__.py +24 -4
  12. pydantic_ai_slim-0.1.0/pydantic_ai/models/_json_schema.py +156 -0
  13. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/anthropic.py +5 -3
  14. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/bedrock.py +100 -22
  15. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/cohere.py +48 -44
  16. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/fallback.py +2 -1
  17. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/function.py +8 -8
  18. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/gemini.py +65 -75
  19. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/groq.py +32 -28
  20. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/instrumented.py +4 -4
  21. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/mistral.py +62 -58
  22. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/openai.py +110 -158
  23. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/test.py +45 -46
  24. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/result.py +203 -90
  25. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/tools.py +3 -3
  26. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pyproject.toml +2 -2
  27. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/.gitignore +0 -0
  28. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/README.md +0 -0
  29. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/__main__.py +0 -0
  30. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_griffe.py +0 -0
  31. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_parts_manager.py +0 -0
  32. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_pydantic.py +0 -0
  33. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_system_prompt.py +0 -0
  34. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/common_tools/__init__.py +0 -0
  35. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  36. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/common_tools/tavily.py +0 -0
  37. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/exceptions.py +0 -0
  38. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/mcp.py +0 -0
  39. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/wrapper.py +0 -0
  40. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/__init__.py +0 -0
  41. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/anthropic.py +0 -0
  42. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/azure.py +0 -0
  43. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/bedrock.py +0 -0
  44. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/cohere.py +0 -0
  45. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/deepseek.py +0 -0
  46. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/google_gla.py +0 -0
  47. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/google_vertex.py +0 -0
  48. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/groq.py +0 -0
  49. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/mistral.py +0 -0
  50. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/openai.py +0 -0
  51. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/py.typed +0 -0
  52. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/settings.py +0 -0
  53. {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.55
3
+ Version: 0.1.0
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,13 +29,13 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.55
32
+ Requires-Dist: pydantic-graph==0.1.0
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
36
36
  Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
37
37
  Provides-Extra: bedrock
38
- Requires-Dist: boto3>=1.34.116; extra == 'bedrock'
38
+ Requires-Dist: boto3>=1.35.74; extra == 'bedrock'
39
39
  Provides-Extra: cli
40
40
  Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
41
41
  Requires-Dist: prompt-toolkit>=3; extra == 'cli'
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
45
45
  Provides-Extra: duckduckgo
46
46
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
47
47
  Provides-Extra: evals
48
- Requires-Dist: pydantic-evals==0.0.55; extra == 'evals'
48
+ Requires-Dist: pydantic-evals==0.1.0; extra == 'evals'
49
49
  Provides-Extra: groq
50
50
  Requires-Dist: groq>=0.15.0; extra == 'groq'
51
51
  Provides-Extra: logfire
@@ -55,7 +55,7 @@ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
55
55
  Provides-Extra: mistral
56
56
  Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
57
57
  Provides-Extra: openai
58
- Requires-Dist: openai>=1.67.0; extra == 'openai'
58
+ Requires-Dist: openai>=1.74.0; extra == 'openai'
59
59
  Provides-Extra: tavily
60
60
  Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
61
61
  Provides-Extra: vertexai
@@ -1,4 +1,4 @@
1
- from importlib.metadata import version
1
+ from importlib.metadata import version as _metadata_version
2
2
 
3
3
  from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
4
4
  from .exceptions import (
@@ -10,7 +10,9 @@ from .exceptions import (
10
10
  UsageLimitExceeded,
11
11
  UserError,
12
12
  )
13
- from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl
13
+ from .format_prompt import format_as_xml
14
+ from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
15
+ from .result import ToolOutput
14
16
  from .tools import RunContext, Tool
15
17
 
16
18
  __all__ = (
@@ -33,10 +35,15 @@ __all__ = (
33
35
  # messages
34
36
  'ImageUrl',
35
37
  'AudioUrl',
38
+ 'VideoUrl',
36
39
  'DocumentUrl',
37
40
  'BinaryContent',
38
41
  # tools
39
42
  'Tool',
40
43
  'RunContext',
44
+ # result
45
+ 'ToolOutput',
46
+ # format_prompt
47
+ 'format_as_xml',
41
48
  )
42
- __version__ = version('pydantic_ai_slim')
49
+ __version__ = _metadata_version('pydantic_ai_slim')
@@ -16,7 +16,7 @@ from pydantic_graph import BaseNode, Graph, GraphRunContext
16
16
  from pydantic_graph.nodes import End, NodeRunEndT
17
17
 
18
18
  from . import (
19
- _result,
19
+ _output,
20
20
  _system_prompt,
21
21
  exceptions,
22
22
  messages as _messages,
@@ -25,7 +25,7 @@ from . import (
25
25
  usage as _usage,
26
26
  )
27
27
  from .models.instrumented import InstrumentedModel
28
- from .result import ResultDataT
28
+ from .result import OutputDataT, ToolOutput
29
29
  from .settings import ModelSettings, merge_model_settings
30
30
  from .tools import RunContext, Tool, ToolDefinition
31
31
 
@@ -53,7 +53,7 @@ EndStrategy = Literal['early', 'exhaustive']
53
53
  - `'exhaustive'`: Process all tool calls even after finding a final result
54
54
  """
55
55
  DepsT = TypeVar('DepsT')
56
- ResultT = TypeVar('ResultT')
56
+ OutputT = TypeVar('OutputT')
57
57
 
58
58
 
59
59
  @dataclasses.dataclass
@@ -74,7 +74,7 @@ class GraphAgentState:
74
74
 
75
75
 
76
76
  @dataclasses.dataclass
77
- class GraphAgentDeps(Generic[DepsT, ResultDataT]):
77
+ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
78
78
  """Dependencies/config passed to the agent graph."""
79
79
 
80
80
  user_deps: DepsT
@@ -88,9 +88,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
88
88
  max_result_retries: int
89
89
  end_strategy: EndStrategy
90
90
 
91
- result_schema: _result.ResultSchema[ResultDataT] | None
92
- result_tools: list[ToolDefinition]
93
- result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
91
+ output_schema: _output.OutputSchema[OutputDataT] | None
92
+ output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
94
93
 
95
94
  function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
96
95
  mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
@@ -126,6 +125,9 @@ def is_agent_node(
126
125
  class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
127
126
  user_prompt: str | Sequence[_messages.UserContent] | None
128
127
 
128
+ instructions: str | None
129
+ instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
130
+
129
131
  system_prompts: tuple[str, ...]
130
132
  system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
131
133
  system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
@@ -167,6 +169,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
167
169
  ctx_messages.used = True
168
170
 
169
171
  parts: list[_messages.ModelRequestPart] = []
172
+ instructions = await self._instructions(run_context)
170
173
  if message_history:
171
174
  # Shallow copy messages
172
175
  messages.extend(message_history)
@@ -177,7 +180,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
177
180
 
178
181
  if user_prompt is not None:
179
182
  parts.append(_messages.UserPromptPart(user_prompt))
180
- return messages, _messages.ModelRequest(parts)
183
+ return messages, _messages.ModelRequest(parts, instructions=instructions)
181
184
 
182
185
  async def _reevaluate_dynamic_prompts(
183
186
  self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
@@ -207,6 +210,15 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
207
210
  messages.append(_messages.SystemPromptPart(prompt))
208
211
  return messages
209
212
 
213
+ async def _instructions(self, run_context: RunContext[DepsT]) -> str | None:
214
+ if self.instructions is None and not self.instructions_functions:
215
+ return None
216
+
217
+ instructions = self.instructions or ''
218
+ for instructions_runner in self.instructions_functions:
219
+ instructions += await instructions_runner.run(run_context)
220
+ return instructions
221
+
210
222
 
211
223
  async def _prepare_request_parameters(
212
224
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -233,11 +245,11 @@ async def _prepare_request_parameters(
233
245
  *map(add_mcp_server_tools, ctx.deps.mcp_servers),
234
246
  )
235
247
 
236
- result_schema = ctx.deps.result_schema
248
+ output_schema = ctx.deps.output_schema
237
249
  return models.ModelRequestParameters(
238
250
  function_tools=function_tool_defs,
239
- allow_text_result=allow_text_result(result_schema),
240
- result_tools=result_schema.tool_defs() if result_schema is not None else [],
251
+ allow_text_output=allow_text_output(output_schema),
252
+ output_tools=output_schema.tool_defs() if output_schema is not None else [],
241
253
  )
242
254
 
243
255
 
@@ -271,8 +283,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
271
283
  async with self._stream(ctx) as streamed_response:
272
284
  agent_stream = result.AgentStream[DepsT, T](
273
285
  streamed_response,
274
- ctx.deps.result_schema,
275
- ctx.deps.result_validators,
286
+ ctx.deps.output_schema,
287
+ ctx.deps.output_validators,
276
288
  build_run_context(ctx),
277
289
  ctx.deps.usage_limits,
278
290
  )
@@ -290,6 +302,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
290
302
  assert not self._did_stream, 'stream() should only be called once per node'
291
303
 
292
304
  model_settings, model_request_parameters = await self._prepare_request(ctx)
305
+ model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
293
306
  async with ctx.deps.model.request_stream(
294
307
  ctx.state.message_history, model_settings, model_request_parameters
295
308
  ) as streamed_response:
@@ -431,17 +444,17 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
431
444
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
432
445
  tool_calls: list[_messages.ToolCallPart],
433
446
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
434
- result_schema = ctx.deps.result_schema
447
+ output_schema = ctx.deps.output_schema
435
448
 
436
- # first look for the result tool call
449
+ # first, look for the output tool call
437
450
  final_result: result.FinalResult[NodeRunEndT] | None = None
438
451
  parts: list[_messages.ModelRequestPart] = []
439
- if result_schema is not None:
440
- for call, result_tool in result_schema.find_tool(tool_calls):
452
+ if output_schema is not None:
453
+ for call, output_tool in output_schema.find_tool(tool_calls):
441
454
  try:
442
- result_data = result_tool.validate(call)
443
- result_data = await _validate_result(result_data, ctx, call)
444
- except _result.ToolRetryError as e:
455
+ result_data = output_tool.validate(call)
456
+ result_data = await _validate_output(result_data, ctx, call)
457
+ except _output.ToolRetryError as e:
445
458
  # TODO: Should only increment retry stuff once per node execution, not for each tool call
446
459
  # Also, should increment the tool-specific retry count rather than the run retry count
447
460
  ctx.state.increment_retries(ctx.deps.max_result_retries)
@@ -488,9 +501,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
488
501
  'all_messages_events': json.dumps(
489
502
  [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
490
503
  ),
491
- 'final_result': final_result.data
492
- if isinstance(final_result.data, str)
493
- else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
504
+ 'final_result': final_result.output
505
+ if isinstance(final_result.output, str)
506
+ else json.dumps(InstrumentedModel.serialize_any(final_result.output)),
494
507
  }
495
508
  )
496
509
  run_span.set_attributes(
@@ -507,7 +520,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
507
520
  }
508
521
  )
509
522
 
510
- # End the run with self.data
511
523
  return End(final_result)
512
524
 
513
525
  async def _handle_text_response(
@@ -515,14 +527,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
515
527
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
516
528
  texts: list[str],
517
529
  ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
518
- result_schema = ctx.deps.result_schema
530
+ output_schema = ctx.deps.output_schema
519
531
 
520
532
  text = '\n\n'.join(texts)
521
- if allow_text_result(result_schema):
533
+ if allow_text_output(output_schema):
522
534
  result_data_input = cast(NodeRunEndT, text)
523
535
  try:
524
- result_data = await _validate_result(result_data_input, ctx, None)
525
- except _result.ToolRetryError as e:
536
+ result_data = await _validate_output(result_data_input, ctx, None)
537
+ except _output.ToolRetryError as e:
526
538
  ctx.state.increment_retries(ctx.deps.max_result_retries)
527
539
  return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
528
540
  else:
@@ -534,7 +546,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
534
546
  _messages.ModelRequest(
535
547
  parts=[
536
548
  _messages.RetryPromptPart(
537
- content='Plain text responses are not permitted, please call one of the functions instead.',
549
+ content='Plain text responses are not permitted, please include your response in a tool call',
538
550
  )
539
551
  ]
540
552
  )
@@ -555,8 +567,8 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
555
567
 
556
568
  async def process_function_tools(
557
569
  tool_calls: list[_messages.ToolCallPart],
558
- result_tool_name: str | None,
559
- result_tool_call_id: str | None,
570
+ output_tool_name: str | None,
571
+ output_tool_call_id: str | None,
560
572
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
561
573
  output_parts: list[_messages.ModelRequestPart],
562
574
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,22 +578,22 @@ async def process_function_tools(
566
578
 
567
579
  Because async iterators can't have return values, we use `output_parts` as an output argument.
568
580
  """
569
- stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
570
- result_schema = ctx.deps.result_schema
581
+ stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early'
582
+ output_schema = ctx.deps.output_schema
571
583
 
572
- # we rely on the fact that if we found a result, it's the first result tool in the last
573
- found_used_result_tool = False
584
+ # we rely on the fact that if we found a result, it's the first output tool in the last
585
+ found_used_output_tool = False
574
586
  run_context = build_run_context(ctx)
575
587
 
576
588
  calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
577
589
  call_index_to_event_id: dict[int, str] = {}
578
590
  for call in tool_calls:
579
591
  if (
580
- call.tool_name == result_tool_name
581
- and call.tool_call_id == result_tool_call_id
582
- and not found_used_result_tool
592
+ call.tool_name == output_tool_name
593
+ and call.tool_call_id == output_tool_call_id
594
+ and not found_used_output_tool
583
595
  ):
584
- found_used_result_tool = True
596
+ found_used_output_tool = True
585
597
  output_parts.append(
586
598
  _messages.ToolReturnPart(
587
599
  tool_name=call.tool_name,
@@ -618,15 +630,15 @@ async def process_function_tools(
618
630
  yield event
619
631
  call_index_to_event_id[len(calls_to_run)] = event.call_id
620
632
  calls_to_run.append((mcp_tool, call))
621
- elif result_schema is not None and call.tool_name in result_schema.tools:
622
- # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
633
+ elif output_schema is not None and call.tool_name in output_schema.tools:
634
+ # if tool_name is in output_schema, it means we found a output tool but an error occurred in
623
635
  # validation, we don't add another part here
624
- if result_tool_name is not None:
625
- if found_used_result_tool:
626
- content = 'Result tool not used - a final result was already processed.'
636
+ if output_tool_name is not None:
637
+ if found_used_output_tool:
638
+ content = 'Output tool not used - a final result was already processed.'
627
639
  else:
628
640
  # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
629
- content = 'Result tool not used - result failed validation.'
641
+ content = 'Output tool not used - result failed validation.'
630
642
  part = _messages.ToolReturnPart(
631
643
  tool_name=call.tool_name,
632
644
  content=content,
@@ -706,8 +718,8 @@ def _unknown_tool(
706
718
  ) -> _messages.RetryPromptPart:
707
719
  ctx.state.increment_retries(ctx.deps.max_result_retries)
708
720
  tool_names = list(ctx.deps.function_tools.keys())
709
- if result_schema := ctx.deps.result_schema:
710
- tool_names.extend(result_schema.tool_names())
721
+ if output_schema := ctx.deps.output_schema:
722
+ tool_names.extend(output_schema.tool_names())
711
723
 
712
724
  if tool_names:
713
725
  msg = f'Available tools: {", ".join(tool_names)}'
@@ -721,20 +733,20 @@ def _unknown_tool(
721
733
  )
722
734
 
723
735
 
724
- async def _validate_result(
736
+ async def _validate_output(
725
737
  result_data: T,
726
738
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
727
739
  tool_call: _messages.ToolCallPart | None,
728
740
  ) -> T:
729
- for validator in ctx.deps.result_validators:
741
+ for validator in ctx.deps.output_validators:
730
742
  run_context = build_run_context(ctx)
731
743
  result_data = await validator.validate(result_data, tool_call, run_context)
732
744
  return result_data
733
745
 
734
746
 
735
- def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
747
+ def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool:
736
748
  """Check if the result schema allows text results."""
737
- return result_schema is None or result_schema.allow_text_result
749
+ return output_schema is None or output_schema.allow_text_output
738
750
 
739
751
 
740
752
  @dataclasses.dataclass
@@ -786,19 +798,19 @@ def get_captured_run_messages() -> _RunMessages:
786
798
 
787
799
 
788
800
  def build_agent_graph(
789
- name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
790
- ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
801
+ name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
802
+ ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
791
803
  """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
792
804
  nodes = (
793
805
  UserPromptNode[DepsT],
794
806
  ModelRequestNode[DepsT],
795
807
  CallToolsNode[DepsT],
796
808
  )
797
- graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
809
+ graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
798
810
  nodes=nodes,
799
811
  name=name or 'Agent',
800
812
  state_type=GraphAgentState,
801
- run_end_type=result.FinalResult[result_type],
813
+ run_end_type=result.FinalResult[OutputT],
802
814
  auto_instrument=False,
803
815
  )
804
816
  return graph
@@ -208,14 +208,13 @@ async def ask_agent(
208
208
  if not stream:
209
209
  with status:
210
210
  result = await agent.run(prompt, message_history=messages)
211
- content = result.data
211
+ content = result.output
212
212
  console.print(Markdown(content, code_theme=code_theme))
213
213
  return result.all_messages()
214
214
 
215
215
  with status, ExitStack() as stack:
216
216
  async with agent.iter(prompt, message_history=messages) as agent_run:
217
217
  live = Live('', refresh_per_second=15, console=console, vertical_overflow='visible')
218
- content: str = ''
219
218
  async for node in agent_run:
220
219
  if Agent.is_model_request_node(node):
221
220
  async with node.stream(agent_run.ctx) as handle_stream:
@@ -12,7 +12,7 @@ from typing_inspection.introspection import is_union_origin
12
12
 
13
13
  from . import _utils, messages as _messages
14
14
  from .exceptions import ModelRetry
15
- from .result import ResultDataT, ResultDataT_inv, ResultValidatorFunc
15
+ from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput
16
16
  from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
17
17
 
18
18
  T = TypeVar('T')
@@ -20,8 +20,8 @@ T = TypeVar('T')
20
20
 
21
21
 
22
22
  @dataclass
23
- class ResultValidator(Generic[AgentDepsT, ResultDataT_inv]):
24
- function: ResultValidatorFunc[AgentDepsT, ResultDataT_inv]
23
+ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
24
+ function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv]
25
25
  _takes_ctx: bool = field(init=False)
26
26
  _is_async: bool = field(init=False)
27
27
 
@@ -77,47 +77,68 @@ class ToolRetryError(Exception):
77
77
 
78
78
 
79
79
  @dataclass
80
- class ResultSchema(Generic[ResultDataT]):
80
+ class OutputSchema(Generic[OutputDataT]):
81
81
  """Model the final response from an agent run.
82
82
 
83
- Similar to `Tool` but for the final result of running an agent.
83
+ Similar to `Tool` but for the final output of running an agent.
84
84
  """
85
85
 
86
- tools: dict[str, ResultTool[ResultDataT]]
87
- allow_text_result: bool
86
+ tools: dict[str, OutputSchemaTool[OutputDataT]]
87
+ allow_text_output: bool
88
88
 
89
89
  @classmethod
90
90
  def build(
91
- cls: type[ResultSchema[T]], response_type: type[T], name: str, description: str | None
92
- ) -> ResultSchema[T] | None:
93
- """Build a ResultSchema dataclass from a response type."""
94
- if response_type is str:
91
+ cls: type[OutputSchema[T]],
92
+ output_type: type[T] | ToolOutput[T],
93
+ name: str | None = None,
94
+ description: str | None = None,
95
+ strict: bool | None = None,
96
+ ) -> OutputSchema[T] | None:
97
+ """Build an OutputSchema dataclass from a response type."""
98
+ if output_type is str:
95
99
  return None
96
100
 
97
- if response_type_option := extract_str_from_union(response_type):
98
- response_type = response_type_option.value
99
- allow_text_result = True
101
+ if isinstance(output_type, ToolOutput):
102
+ # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
103
+ name = output_type.name
104
+ description = output_type.description
105
+ output_type_ = output_type.output_type
106
+ strict = output_type.strict
100
107
  else:
101
- allow_text_result = False
108
+ output_type_ = output_type
102
109
 
103
- def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[T]:
104
- return cast(ResultTool[T], ResultTool(a, tool_name_, description, multiple))
110
+ if output_type_option := extract_str_from_union(output_type):
111
+ output_type_ = output_type_option.value
112
+ allow_text_output = True
113
+ else:
114
+ allow_text_output = False
105
115
 
106
- tools: dict[str, ResultTool[T]] = {}
107
- if args := get_union_args(response_type):
116
+ tools: dict[str, OutputSchemaTool[T]] = {}
117
+ if args := get_union_args(output_type_):
108
118
  for i, arg in enumerate(args, start=1):
109
- tool_name = union_tool_name(name, arg)
119
+ tool_name = raw_tool_name = union_tool_name(name, arg)
110
120
  while tool_name in tools:
111
- tool_name = f'{tool_name}_{i}'
112
- tools[tool_name] = _build_tool(arg, tool_name, True)
121
+ tool_name = f'{raw_tool_name}_{i}'
122
+ tools[tool_name] = cast(
123
+ OutputSchemaTool[T],
124
+ OutputSchemaTool(
125
+ output_type=arg, name=tool_name, description=description, multiple=True, strict=strict
126
+ ),
127
+ )
113
128
  else:
114
- tools[name] = _build_tool(response_type, name, False)
129
+ name = name or DEFAULT_OUTPUT_TOOL_NAME
130
+ tools[name] = cast(
131
+ OutputSchemaTool[T],
132
+ OutputSchemaTool(
133
+ output_type=output_type_, name=name, description=description, multiple=False, strict=strict
134
+ ),
135
+ )
115
136
 
116
- return cls(tools=tools, allow_text_result=allow_text_result)
137
+ return cls(tools=tools, allow_text_output=allow_text_output)
117
138
 
118
139
  def find_named_tool(
119
140
  self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
120
- ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
141
+ ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None:
121
142
  """Find a tool that matches one of the calls, with a specific name."""
122
143
  for part in parts:
123
144
  if isinstance(part, _messages.ToolCallPart):
@@ -127,7 +148,7 @@ class ResultSchema(Generic[ResultDataT]):
127
148
  def find_tool(
128
149
  self,
129
150
  parts: Iterable[_messages.ModelResponsePart],
130
- ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
151
+ ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]:
131
152
  """Find a tool that matches one of the calls."""
132
153
  for part in parts:
133
154
  if isinstance(part, _messages.ToolCallPart):
@@ -147,16 +168,16 @@ DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
147
168
 
148
169
 
149
170
  @dataclass(init=False)
150
- class ResultTool(Generic[ResultDataT]):
171
+ class OutputSchemaTool(Generic[OutputDataT]):
151
172
  tool_def: ToolDefinition
152
173
  type_adapter: TypeAdapter[Any]
153
174
 
154
- def __init__(self, response_type: type[ResultDataT], name: str, description: str | None, multiple: bool):
155
- """Build a ResultTool dataclass from a response type."""
156
- assert response_type is not str, 'ResultTool does not support str as a response type'
157
-
158
- if _utils.is_model_like(response_type):
159
- self.type_adapter = TypeAdapter(response_type)
175
+ def __init__(
176
+ self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None
177
+ ):
178
+ """Build a OutputSchemaTool from a response type."""
179
+ if _utils.is_model_like(output_type):
180
+ self.type_adapter = TypeAdapter(output_type)
160
181
  outer_typed_dict_key: str | None = None
161
182
  # noinspection PyArgumentList
162
183
  parameters_json_schema = _utils.check_object_json_schema(
@@ -165,7 +186,7 @@ class ResultTool(Generic[ResultDataT]):
165
186
  else:
166
187
  response_data_typed_dict = TypedDict( # noqa: UP013
167
188
  'response_data_typed_dict',
168
- {'response': response_type}, # pyright: ignore[reportInvalidTypeForm]
189
+ {'response': output_type}, # pyright: ignore[reportInvalidTypeForm]
169
190
  )
170
191
  self.type_adapter = TypeAdapter(response_data_typed_dict)
171
192
  outer_typed_dict_key = 'response'
@@ -184,19 +205,20 @@ class ResultTool(Generic[ResultDataT]):
184
205
  else:
185
206
  tool_description = description or DEFAULT_DESCRIPTION
186
207
  if multiple:
187
- tool_description = f'{union_arg_name(response_type)}: {tool_description}'
208
+ tool_description = f'{union_arg_name(output_type)}: {tool_description}'
188
209
 
189
210
  self.tool_def = ToolDefinition(
190
211
  name=name,
191
212
  description=tool_description,
192
213
  parameters_json_schema=parameters_json_schema,
193
214
  outer_typed_dict_key=outer_typed_dict_key,
215
+ strict=strict,
194
216
  )
195
217
 
196
218
  def validate(
197
219
  self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
198
- ) -> ResultDataT:
199
- """Validate a result message.
220
+ ) -> OutputDataT:
221
+ """Validate an output message.
200
222
 
201
223
  Args:
202
224
  tool_call: The tool call from the LLM to validate.
@@ -204,14 +226,14 @@ class ResultTool(Generic[ResultDataT]):
204
226
  wrap_validation_errors: If true, wrap the validation errors in a retry message.
205
227
 
206
228
  Returns:
207
- Either the validated result data (left) or a retry message (right).
229
+ Either the validated output data (left) or a retry message (right).
208
230
  """
209
231
  try:
210
232
  pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
211
233
  if isinstance(tool_call.args, str):
212
- result = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
234
+ output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
213
235
  else:
214
- result = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
236
+ output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
215
237
  except ValidationError as e:
216
238
  if wrap_validation_errors:
217
239
  m = _messages.RetryPromptPart(
@@ -224,21 +246,21 @@ class ResultTool(Generic[ResultDataT]):
224
246
  raise
225
247
  else:
226
248
  if k := self.tool_def.outer_typed_dict_key:
227
- result = result[k]
228
- return result
249
+ output = output[k]
250
+ return output
229
251
 
230
252
 
231
- def union_tool_name(base_name: str, union_arg: Any) -> str:
232
- return f'{base_name}_{union_arg_name(union_arg)}'
253
+ def union_tool_name(base_name: str | None, union_arg: Any) -> str:
254
+ return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}'
233
255
 
234
256
 
235
257
  def union_arg_name(union_arg: Any) -> str:
236
258
  return union_arg.__name__
237
259
 
238
260
 
239
- def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
261
+ def extract_str_from_union(output_type: Any) -> _utils.Option[Any]:
240
262
  """Extract the string type from a Union, return the remaining union or remaining type."""
241
- union_args = get_union_args(response_type)
263
+ union_args = get_union_args(output_type)
242
264
  if any(t is str for t in union_args):
243
265
  remain_args: list[Any] = []
244
266
  includes_str = False
@@ -255,7 +277,7 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
255
277
 
256
278
 
257
279
  def get_union_args(tp: Any) -> tuple[Any, ...]:
258
- """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
280
+ """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple."""
259
281
  if typing_objects.is_typealiastype(tp):
260
282
  tp = tp.__value__
261
283