pydantic-ai-slim 0.0.46__tar.gz → 0.0.48__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the registry's advisory page for this release for more details.

Files changed (51):
  1. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/PKG-INFO +6 -4
  2. pydantic_ai_slim-0.0.48/pydantic_ai/__main__.py +6 -0
  3. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_agent_graph.py +19 -13
  4. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_cli.py +120 -77
  5. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_result.py +11 -4
  6. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_utils.py +1 -1
  7. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/agent.py +30 -30
  8. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/messages.py +1 -1
  9. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/__init__.py +206 -193
  10. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/anthropic.py +4 -1
  11. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/bedrock.py +7 -0
  12. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/cohere.py +4 -1
  13. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/gemini.py +4 -1
  14. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/groq.py +32 -15
  15. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/instrumented.py +6 -1
  16. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/mistral.py +6 -1
  17. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/openai.py +415 -11
  18. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/bedrock.py +11 -0
  19. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/tools.py +34 -3
  20. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pyproject.toml +18 -6
  21. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/.gitignore +0 -0
  22. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/README.md +0 -0
  23. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/__init__.py +0 -0
  24. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_griffe.py +0 -0
  25. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_parts_manager.py +0 -0
  26. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_pydantic.py +0 -0
  27. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/_system_prompt.py +0 -0
  28. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/common_tools/__init__.py +0 -0
  29. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  30. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/common_tools/tavily.py +0 -0
  31. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/exceptions.py +0 -0
  32. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/format_as_xml.py +0 -0
  33. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/mcp.py +0 -0
  34. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/fallback.py +0 -0
  35. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/function.py +0 -0
  36. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/test.py +0 -0
  37. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/models/wrapper.py +0 -0
  38. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/__init__.py +0 -0
  39. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/anthropic.py +0 -0
  40. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/azure.py +0 -0
  41. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/cohere.py +0 -0
  42. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/deepseek.py +0 -0
  43. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/google_gla.py +0 -0
  44. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/google_vertex.py +0 -0
  45. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/groq.py +0 -0
  46. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/mistral.py +0 -0
  47. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/providers/openai.py +0 -0
  48. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/py.typed +0 -0
  49. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/result.py +0 -0
  50. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/settings.py +0 -0
  51. {pydantic_ai_slim-0.0.46 → pydantic_ai_slim-0.0.48}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.46
3
+ Version: 0.0.48
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.46
32
+ Requires-Dist: pydantic-graph==0.0.48
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
@@ -41,13 +41,15 @@ Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
41
41
  Requires-Dist: prompt-toolkit>=3; extra == 'cli'
42
42
  Requires-Dist: rich>=13; extra == 'cli'
43
43
  Provides-Extra: cohere
44
- Requires-Dist: cohere>=5.13.11; extra == 'cohere'
44
+ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
45
45
  Provides-Extra: duckduckgo
46
46
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
47
+ Provides-Extra: evals
48
+ Requires-Dist: pydantic-evals==0.0.48; extra == 'evals'
47
49
  Provides-Extra: groq
48
50
  Requires-Dist: groq>=0.15.0; extra == 'groq'
49
51
  Provides-Extra: logfire
50
- Requires-Dist: logfire>=2.3; extra == 'logfire'
52
+ Requires-Dist: logfire>=3.11.0; extra == 'logfire'
51
53
  Provides-Extra: mcp
52
54
  Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
53
55
  Provides-Extra: mistral
@@ -0,0 +1,6 @@
1
+ """This means `python -m pydantic_ai` should run the CLI."""
2
+
3
+ from ._cli import app
4
+
5
+ if __name__ == '__main__':
6
+ app()
@@ -331,10 +331,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
331
331
  ctx.state.run_step += 1
332
332
 
333
333
  model_settings = merge_model_settings(ctx.deps.model_settings, None)
334
- with ctx.deps.tracer.start_as_current_span(
335
- 'preparing model request params', attributes=dict(run_step=ctx.state.run_step)
336
- ):
337
- model_request_parameters = await _prepare_request_parameters(ctx)
334
+ model_request_parameters = await _prepare_request_parameters(ctx)
338
335
  return model_settings, model_request_parameters
339
336
 
340
337
  def _finish_handling(
@@ -374,9 +371,8 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
374
371
  ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]: # noqa UP007
375
372
  async with self.stream(ctx):
376
373
  pass
377
-
378
- assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
379
- return next_node
374
+ assert self._next_node is not None, 'the stream should set `self._next_node` before it ends'
375
+ return self._next_node
380
376
 
381
377
  @asynccontextmanager
382
378
  async def stream(
@@ -635,19 +631,24 @@ async def process_function_tools(
635
631
  )
636
632
  output_parts.append(part)
637
633
  else:
638
- output_parts.append(_unknown_tool(call.tool_name, ctx))
634
+ output_parts.append(_unknown_tool(call.tool_name, call.tool_call_id, ctx))
639
635
 
640
636
  if not calls_to_run:
641
637
  return
642
638
 
643
639
  # Run all tool tasks in parallel
644
640
  results_by_index: dict[int, _messages.ModelRequestPart] = {}
645
- tool_names = [call.tool_name for _, call in calls_to_run]
646
641
  with ctx.deps.tracer.start_as_current_span(
647
- 'running tools', attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'}
642
+ 'running tools',
643
+ attributes={
644
+ 'tools': [call.tool_name for _, call in calls_to_run],
645
+ 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
646
+ },
648
647
  ):
649
- # TODO: Should we wrap each individual tool call in a dedicated span?
650
- tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
648
+ tasks = [
649
+ asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
650
+ for tool, call in calls_to_run
651
+ ]
651
652
  pending = tasks
652
653
  while pending:
653
654
  done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
@@ -697,6 +698,7 @@ async def _tool_from_mcp_server(
697
698
 
698
699
  def _unknown_tool(
699
700
  tool_name: str,
701
+ tool_call_id: str,
700
702
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
701
703
  ) -> _messages.RetryPromptPart:
702
704
  ctx.state.increment_retries(ctx.deps.max_result_retries)
@@ -709,7 +711,11 @@ def _unknown_tool(
709
711
  else:
710
712
  msg = 'No tools available.'
711
713
 
712
- return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
714
+ return _messages.RetryPromptPart(
715
+ tool_name=tool_name,
716
+ tool_call_id=tool_call_id,
717
+ content=f'Unknown tool name: {tool_name!r}. {msg}',
718
+ )
713
719
 
714
720
 
715
721
  async def _validate_result(
@@ -4,6 +4,7 @@ import argparse
4
4
  import asyncio
5
5
  import sys
6
6
  from collections.abc import Sequence
7
+ from contextlib import ExitStack
7
8
  from datetime import datetime, timezone
8
9
  from importlib.metadata import version
9
10
  from pathlib import Path
@@ -11,9 +12,10 @@ from typing import cast
11
12
 
12
13
  from typing_inspection.introspection import get_literal_values
13
14
 
15
+ from pydantic_ai.agent import Agent
14
16
  from pydantic_ai.exceptions import UserError
17
+ from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
15
18
  from pydantic_ai.models import KnownModelName
16
- from pydantic_graph.nodes import End
17
19
 
18
20
  try:
19
21
  import argcomplete
@@ -24,8 +26,9 @@ try:
24
26
  from prompt_toolkit.history import FileHistory
25
27
  from rich.console import Console, ConsoleOptions, RenderResult
26
28
  from rich.live import Live
27
- from rich.markdown import CodeBlock, Markdown
29
+ from rich.markdown import CodeBlock, Heading, Markdown
28
30
  from rich.status import Status
31
+ from rich.style import Style
29
32
  from rich.syntax import Syntax
30
33
  from rich.text import Text
31
34
  except ImportError as _import_error:
@@ -34,13 +37,16 @@ except ImportError as _import_error:
34
37
  'you can use the `cli` optional group — `pip install "pydantic-ai-slim[cli]"`'
35
38
  ) from _import_error
36
39
 
37
- from pydantic_ai.agent import Agent
38
- from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
39
40
 
40
41
  __version__ = version('pydantic-ai-slim')
41
42
 
42
43
 
43
44
  class SimpleCodeBlock(CodeBlock):
45
+ """Customised code blocks in markdown.
46
+
47
+ This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
48
+ """
49
+
44
50
  def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: # pragma: no cover
45
51
  code = str(self.text).rstrip()
46
52
  yield Text(self.lexer_name, style='dim')
@@ -48,7 +54,18 @@ class SimpleCodeBlock(CodeBlock):
48
54
  yield Text(f'/{self.lexer_name}', style='dim')
49
55
 
50
56
 
51
- Markdown.elements['fence'] = SimpleCodeBlock
57
+ class LeftHeading(Heading):
58
+ """Customised headings in markdown to stop centering and prepend markdown style hashes."""
59
+
60
+ def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: # pragma: no cover
61
+ # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
62
+ yield Text(f'{"#" * int(self.tag[1:])} {self.text.plain}', style=Style(bold=True))
63
+
64
+
65
+ Markdown.elements.update(
66
+ fence=SimpleCodeBlock,
67
+ heading_open=LeftHeading,
68
+ )
52
69
 
53
70
 
54
71
  def cli(args_list: Sequence[str] | None = None) -> int: # noqa: C901 # pragma: no cover
@@ -65,28 +82,53 @@ Special prompt:
65
82
  formatter_class=argparse.RawTextHelpFormatter,
66
83
  )
67
84
  parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
68
- parser.add_argument(
85
+ arg = parser.add_argument(
86
+ '-m',
69
87
  '--model',
70
88
  nargs='?',
71
- help='Model to use, it should be "<provider>:<model>" e.g. "openai:gpt-4o". If omitted it will default to "openai:gpt-4o"',
89
+ help='Model to use, in format "<provider>:<model>" e.g. "openai:gpt-4o". Defaults to "openai:gpt-4o".',
72
90
  default='openai:gpt-4o',
73
- ).completer = argcomplete.ChoicesCompleter(list(get_literal_values(KnownModelName))) # type: ignore[reportPrivateUsage]
74
- parser.add_argument('--no-stream', action='store_true', help='Whether to stream responses from OpenAI')
91
+ )
92
+ # we don't want to autocomplete or list models that don't include the provider,
93
+ # e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
94
+ qualified_model_names = [n for n in get_literal_values(KnownModelName.__value__) if ':' in n]
95
+ arg.completer = argcomplete.ChoicesCompleter(qualified_model_names) # type: ignore[reportPrivateUsage]
96
+ parser.add_argument(
97
+ '-l',
98
+ '--list-models',
99
+ action='store_true',
100
+ help='List all available models and exit',
101
+ )
102
+ parser.add_argument(
103
+ '-t',
104
+ '--code-theme',
105
+ nargs='?',
106
+ help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "monokai".',
107
+ default='monokai',
108
+ )
109
+ parser.add_argument('--no-stream', action='store_true', help='Whether to stream responses from the model')
75
110
  parser.add_argument('--version', action='store_true', help='Show version and exit')
76
111
 
77
112
  argcomplete.autocomplete(parser)
78
113
  args = parser.parse_args(args_list)
79
114
 
80
115
  console = Console()
81
- console.print(f'pai - PydanticAI CLI v{__version__}', style='green bold', highlight=False)
116
+ console.print(
117
+ f'[green]pai - PydanticAI CLI v{__version__} using[/green] [magenta]{args.model}[/magenta]', highlight=False
118
+ )
82
119
  if args.version:
83
120
  return 0
121
+ if args.list_models:
122
+ console.print('Available models:', style='green bold')
123
+ for model in qualified_model_names:
124
+ console.print(f' {model}', highlight=False)
125
+ return 0
84
126
 
85
127
  now_utc = datetime.now(timezone.utc)
86
128
  tzname = now_utc.astimezone().tzinfo.tzname(now_utc) # type: ignore
87
129
  try:
88
130
  agent = Agent(
89
- model=args.model or 'openai:gpt-4o',
131
+ model=args.model,
90
132
  system_prompt=f"""\
91
133
  Help the user by responding to their request, the output should be concise and always written in markdown.
92
134
  The current date and time is {datetime.now()} {tzname}.
@@ -97,10 +139,16 @@ Special prompt:
97
139
  return 1
98
140
 
99
141
  stream = not args.no_stream
142
+ if args.code_theme == 'light':
143
+ code_theme = 'default'
144
+ elif args.code_theme == 'dark':
145
+ code_theme = 'monokai'
146
+ else:
147
+ code_theme = args.code_theme
100
148
 
101
149
  if prompt := cast(str, args.prompt):
102
150
  try:
103
- asyncio.run(ask_agent(agent, prompt, stream, console))
151
+ asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
104
152
  except KeyboardInterrupt:
105
153
  pass
106
154
  return 0
@@ -121,37 +169,46 @@ Special prompt:
121
169
  continue
122
170
 
123
171
  ident_prompt = text.lower().strip(' ').replace(' ', '-').lstrip(' ')
124
- if ident_prompt == '/markdown':
125
- try:
126
- parts = messages[-1].parts
127
- except IndexError:
128
- console.print('[dim]No markdown output available.[/dim]')
129
- continue
130
- for part in parts:
131
- if part.part_kind == 'text':
132
- last_content = part.content
133
- console.print('[dim]Last markdown output of last question:[/dim]\n')
134
- console.print(Syntax(last_content, lexer='markdown', background_color='default'))
135
-
136
- continue
137
- if ident_prompt == '/multiline':
138
- multiline = not multiline
139
- if multiline:
140
- console.print(
141
- 'Enabling multiline mode. '
142
- '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
143
- )
172
+ if ident_prompt.startswith('/'):
173
+ if ident_prompt == '/markdown':
174
+ try:
175
+ parts = messages[-1].parts
176
+ except IndexError:
177
+ console.print('[dim]No markdown output available.[/dim]')
178
+ continue
179
+ console.print('[dim]Markdown output of last question:[/dim]\n')
180
+ for part in parts:
181
+ if part.part_kind == 'text':
182
+ console.print(
183
+ Syntax(
184
+ part.content,
185
+ lexer='markdown',
186
+ theme=code_theme,
187
+ word_wrap=True,
188
+ background_color='default',
189
+ )
190
+ )
191
+
192
+ elif ident_prompt == '/multiline':
193
+ multiline = not multiline
194
+ if multiline:
195
+ console.print(
196
+ 'Enabling multiline mode. '
197
+ '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
198
+ )
199
+ else:
200
+ console.print('Disabling multiline mode.')
201
+ elif ident_prompt == '/exit':
202
+ console.print('[dim]Exiting…[/dim]')
203
+ return 0
144
204
  else:
145
- console.print('Disabling multiline mode.')
146
- continue
147
- if ident_prompt == '/exit':
148
- console.print('[dim]Exiting…[/dim]')
149
- return 0
150
-
151
- try:
152
- messages = asyncio.run(ask_agent(agent, text, stream, console, messages))
153
- except KeyboardInterrupt:
154
- return 0
205
+ console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
206
+ else:
207
+ try:
208
+ messages = asyncio.run(ask_agent(agent, text, stream, console, code_theme, messages))
209
+ except KeyboardInterrupt:
210
+ console.print('[dim]Interrupted[/dim]')
211
+ messages = []
155
212
 
156
213
 
157
214
  async def ask_agent(
@@ -159,48 +216,34 @@ async def ask_agent(
159
216
  prompt: str,
160
217
  stream: bool,
161
218
  console: Console,
219
+ code_theme: str,
162
220
  messages: list[ModelMessage] | None = None,
163
221
  ) -> list[ModelMessage]: # pragma: no cover
164
- status: None | Status = Status('[dim]Working on it…[/dim]', console=console)
165
- live = Live('', refresh_per_second=15, console=console)
166
- status.start()
167
-
168
- async with agent.iter(prompt, message_history=messages) as agent_run:
169
- console.print('\nResponse:', style='green')
170
-
171
- content: str = ''
172
- interrupted = False
173
- try:
174
- node = agent_run.next_node
175
- while not isinstance(node, End):
176
- node = await agent_run.next(node)
222
+ status = Status('[dim]Working on it…[/dim]', console=console)
223
+
224
+ if not stream:
225
+ with status:
226
+ result = await agent.run(prompt, message_history=messages)
227
+ content = result.data
228
+ console.print(Markdown(content, code_theme=code_theme))
229
+ return result.all_messages()
230
+
231
+ with status, ExitStack() as stack:
232
+ async with agent.iter(prompt, message_history=messages) as agent_run:
233
+ live = Live('', refresh_per_second=15, console=console, vertical_overflow='visible')
234
+ content: str = ''
235
+ async for node in agent_run:
177
236
  if Agent.is_model_request_node(node):
178
237
  async with node.stream(agent_run.ctx) as handle_stream:
179
- # NOTE(Marcelo): It took me a lot of time to figure out how to stop `status` and start `live`
180
- # in a context manager, so I had to do it manually with `stop` and `start` methods.
181
- # PR welcome to simplify this code.
182
- if status is not None:
183
- status.stop()
184
- status = None
185
- if not live.is_started:
186
- live.start()
238
+ status.stop() # stopping multiple times is idempotent
239
+ stack.enter_context(live) # entering multiple times is idempotent
240
+
187
241
  async for event in handle_stream:
188
242
  if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
189
- if stream:
190
- content += event.delta.content_delta
191
- live.update(Markdown(content))
192
- except KeyboardInterrupt:
193
- interrupted = True
194
- finally:
195
- live.stop()
196
-
197
- if interrupted:
198
- console.print('[dim]Interrupted[/dim]')
243
+ content += event.delta.content_delta
244
+ live.update(Markdown(content, code_theme=code_theme))
199
245
 
200
- assert agent_run.result
201
- if not stream:
202
- content = agent_run.result.data
203
- console.print(Markdown(content))
246
+ assert agent_run.result is not None
204
247
  return agent_run.result.all_messages()
205
248
 
206
249
 
@@ -13,7 +13,7 @@ from typing_inspection.introspection import is_union_origin
13
13
  from . import _utils, messages as _messages
14
14
  from .exceptions import ModelRetry
15
15
  from .result import ResultDataT, ResultDataT_inv, ResultValidatorFunc
16
- from .tools import AgentDepsT, RunContext, ToolDefinition
16
+ from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
17
17
 
18
18
  T = TypeVar('T')
19
19
  """An invariant TypeVar."""
@@ -159,13 +159,20 @@ class ResultTool(Generic[ResultDataT]):
159
159
  self.type_adapter = TypeAdapter(response_type)
160
160
  outer_typed_dict_key: str | None = None
161
161
  # noinspection PyArgumentList
162
- parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
162
+ parameters_json_schema = _utils.check_object_json_schema(
163
+ self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
164
+ )
163
165
  else:
164
- response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
166
+ response_data_typed_dict = TypedDict( # noqa: UP013
167
+ 'response_data_typed_dict',
168
+ {'response': response_type}, # pyright: ignore[reportInvalidTypeForm]
169
+ )
165
170
  self.type_adapter = TypeAdapter(response_data_typed_dict)
166
171
  outer_typed_dict_key = 'response'
167
172
  # noinspection PyArgumentList
168
- parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
173
+ parameters_json_schema = _utils.check_object_json_schema(
174
+ self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
175
+ )
169
176
  # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
170
177
  parameters_json_schema.pop('title')
171
178
 
@@ -40,7 +40,7 @@ def is_model_like(type_: Any) -> bool:
40
40
  return (
41
41
  isinstance(type_, type)
42
42
  and not isinstance(type_, GenericAlias)
43
- and (issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_))
43
+ and (issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)) # pyright: ignore[reportUnknownArgumentType]
44
44
  )
45
45
 
46
46
 
@@ -94,9 +94,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
94
94
  ```
95
95
  """
96
96
 
97
- # we use dataclass fields in order to conveniently know what attributes are available
98
- model: models.Model | models.KnownModelName | None
99
- """The default model configured for this agent."""
97
+ model: models.Model | models.KnownModelName | str | None
98
+ """The default model configured for this agent.
99
+
100
+ We allow str here since the actual list of allowed models changes frequently.
101
+ """
100
102
 
101
103
  name: str | None
102
104
  """The name of the agent, used for logging.
@@ -142,7 +144,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
142
144
 
143
145
  def __init__(
144
146
  self,
145
- model: models.Model | models.KnownModelName | None = None,
147
+ model: models.Model | models.KnownModelName | str | None = None,
146
148
  *,
147
149
  result_type: type[ResultDataT] = str,
148
150
  system_prompt: str | Sequence[str] = (),
@@ -163,7 +165,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
163
165
 
164
166
  Args:
165
167
  model: The default model to use for this agent, if not provide,
166
- you must provide the model when calling it.
168
+ you must provide the model when calling it. We allow str here since the actual list of allowed models changes frequently.
167
169
  result_type: The type of the result data, used to validate the result data, defaults to `str`.
168
170
  system_prompt: Static system prompts to use for this agent, you can also register system
169
171
  prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -212,16 +214,16 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
212
214
 
213
215
  self._result_tool_name = result_tool_name
214
216
  self._result_tool_description = result_tool_description
215
- self._result_schema: _result.ResultSchema[ResultDataT] | None = _result.ResultSchema[result_type].build(
217
+ self._result_schema = _result.ResultSchema[result_type].build(
216
218
  result_type, result_tool_name, result_tool_description
217
219
  )
218
- self._result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = []
220
+ self._result_validators = []
219
221
 
220
222
  self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
221
- self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
222
- self._system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = {}
223
+ self._system_prompt_functions = []
224
+ self._system_prompt_dynamic_functions = {}
223
225
 
224
- self._function_tools: dict[str, Tool[AgentDepsT]] = {}
226
+ self._function_tools = {}
225
227
 
226
228
  self._default_retries = retries
227
229
  self._max_result_retries = result_retries if result_retries is not None else retries
@@ -244,7 +246,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
244
246
  *,
245
247
  result_type: None = None,
246
248
  message_history: list[_messages.ModelMessage] | None = None,
247
- model: models.Model | models.KnownModelName | None = None,
249
+ model: models.Model | models.KnownModelName | str | None = None,
248
250
  deps: AgentDepsT = None,
249
251
  model_settings: ModelSettings | None = None,
250
252
  usage_limits: _usage.UsageLimits | None = None,
@@ -259,7 +261,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
259
261
  *,
260
262
  result_type: type[RunResultDataT],
261
263
  message_history: list[_messages.ModelMessage] | None = None,
262
- model: models.Model | models.KnownModelName | None = None,
264
+ model: models.Model | models.KnownModelName | str | None = None,
263
265
  deps: AgentDepsT = None,
264
266
  model_settings: ModelSettings | None = None,
265
267
  usage_limits: _usage.UsageLimits | None = None,
@@ -273,7 +275,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
273
275
  *,
274
276
  result_type: type[RunResultDataT] | None = None,
275
277
  message_history: list[_messages.ModelMessage] | None = None,
276
- model: models.Model | models.KnownModelName | None = None,
278
+ model: models.Model | models.KnownModelName | str | None = None,
277
279
  deps: AgentDepsT = None,
278
280
  model_settings: ModelSettings | None = None,
279
281
  usage_limits: _usage.UsageLimits | None = None,
@@ -327,8 +329,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
327
329
  async for _ in agent_run:
328
330
  pass
329
331
 
330
- assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
331
- return final_result
332
+ assert agent_run.result is not None, 'The graph run did not finish properly'
333
+ return agent_run.result
332
334
 
333
335
  @asynccontextmanager
334
336
  async def iter(
@@ -337,7 +339,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
337
339
  *,
338
340
  result_type: type[RunResultDataT] | None = None,
339
341
  message_history: list[_messages.ModelMessage] | None = None,
340
- model: models.Model | models.KnownModelName | None = None,
342
+ model: models.Model | models.KnownModelName | str | None = None,
341
343
  deps: AgentDepsT = None,
342
344
  model_settings: ModelSettings | None = None,
343
345
  usage_limits: _usage.UsageLimits | None = None,
@@ -498,7 +500,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
498
500
  user_prompt: str | Sequence[_messages.UserContent],
499
501
  *,
500
502
  message_history: list[_messages.ModelMessage] | None = None,
501
- model: models.Model | models.KnownModelName | None = None,
503
+ model: models.Model | models.KnownModelName | str | None = None,
502
504
  deps: AgentDepsT = None,
503
505
  model_settings: ModelSettings | None = None,
504
506
  usage_limits: _usage.UsageLimits | None = None,
@@ -513,7 +515,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
513
515
  *,
514
516
  result_type: type[RunResultDataT] | None,
515
517
  message_history: list[_messages.ModelMessage] | None = None,
516
- model: models.Model | models.KnownModelName | None = None,
518
+ model: models.Model | models.KnownModelName | str | None = None,
517
519
  deps: AgentDepsT = None,
518
520
  model_settings: ModelSettings | None = None,
519
521
  usage_limits: _usage.UsageLimits | None = None,
@@ -527,7 +529,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
527
529
  *,
528
530
  result_type: type[RunResultDataT] | None = None,
529
531
  message_history: list[_messages.ModelMessage] | None = None,
530
- model: models.Model | models.KnownModelName | None = None,
532
+ model: models.Model | models.KnownModelName | str | None = None,
531
533
  deps: AgentDepsT = None,
532
534
  model_settings: ModelSettings | None = None,
533
535
  usage_limits: _usage.UsageLimits | None = None,
@@ -588,7 +590,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
588
590
  *,
589
591
  result_type: None = None,
590
592
  message_history: list[_messages.ModelMessage] | None = None,
591
- model: models.Model | models.KnownModelName | None = None,
593
+ model: models.Model | models.KnownModelName | str | None = None,
592
594
  deps: AgentDepsT = None,
593
595
  model_settings: ModelSettings | None = None,
594
596
  usage_limits: _usage.UsageLimits | None = None,
@@ -603,7 +605,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
603
605
  *,
604
606
  result_type: type[RunResultDataT],
605
607
  message_history: list[_messages.ModelMessage] | None = None,
606
- model: models.Model | models.KnownModelName | None = None,
608
+ model: models.Model | models.KnownModelName | str | None = None,
607
609
  deps: AgentDepsT = None,
608
610
  model_settings: ModelSettings | None = None,
609
611
  usage_limits: _usage.UsageLimits | None = None,
@@ -618,7 +620,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
618
620
  *,
619
621
  result_type: type[RunResultDataT] | None = None,
620
622
  message_history: list[_messages.ModelMessage] | None = None,
621
- model: models.Model | models.KnownModelName | None = None,
623
+ model: models.Model | models.KnownModelName | str | None = None,
622
624
  deps: AgentDepsT = None,
623
625
  model_settings: ModelSettings | None = None,
624
626
  usage_limits: _usage.UsageLimits | None = None,
@@ -757,12 +759,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
757
759
  self,
758
760
  *,
759
761
  deps: AgentDepsT | _utils.Unset = _utils.UNSET,
760
- model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
762
+ model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
761
763
  ) -> Iterator[None]:
762
764
  """Context manager to temporarily override agent dependencies and model.
763
765
 
764
766
  This is particularly useful when testing.
765
- You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).
767
+ You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
766
768
 
767
769
  Args:
768
770
  deps: The dependencies to use instead of the dependencies passed to the agent run.
@@ -774,11 +776,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
774
776
  else:
775
777
  override_deps_before = _utils.UNSET
776
778
 
777
- # noinspection PyTypeChecker
778
779
  if _utils.is_set(model):
779
780
  override_model_before = self._override_model
780
- # noinspection PyTypeChecker
781
- self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType]
781
+ self._override_model = _utils.Some(models.infer_model(model))
782
782
  else:
783
783
  override_model_before = _utils.UNSET
784
784
 
@@ -1154,7 +1154,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1154
1154
 
1155
1155
  self._function_tools[tool.name] = tool
1156
1156
 
1157
- def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
1157
+ def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
1158
1158
  """Create a model configured for this agent.
1159
1159
 
1160
1160
  Args:
@@ -1168,7 +1168,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1168
1168
  # we don't want `override()` to cover up errors from the model not being defined, hence this check
1169
1169
  if model is None and self.model is None:
1170
1170
  raise exceptions.UserError(
1171
- '`model` must be set either when creating the agent or when calling it. '
1171
+ '`model` must either be set on the agent or included when calling it. '
1172
1172
  '(Even when `override(model=...)` is customizing the model that will actually be called)'
1173
1173
  )
1174
1174
  model_ = some_model.value
@@ -1178,7 +1178,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1178
1178
  # noinspection PyTypeChecker
1179
1179
  model_ = self.model = models.infer_model(self.model)
1180
1180
  else:
1181
- raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
1181
+ raise exceptions.UserError('`model` must either be set on the agent or included when calling it.')
1182
1182
 
1183
1183
  instrument = self.instrument
1184
1184
  if instrument is None: