pydantic-ai-slim 0.2.4.tar.gz → 0.2.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/.gitignore +2 -0
  2. pydantic_ai_slim-0.2.6/LICENSE +21 -0
  3. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/PKG-INFO +7 -4
  4. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_agent_graph.py +16 -7
  5. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_cli.py +11 -12
  6. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_output.py +7 -7
  7. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_parts_manager.py +1 -1
  8. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/agent.py +30 -18
  9. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/direct.py +2 -0
  10. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/exceptions.py +2 -2
  11. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/messages.py +29 -11
  12. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/__init__.py +43 -6
  13. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/anthropic.py +17 -12
  14. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/bedrock.py +10 -9
  15. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/cohere.py +4 -4
  16. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/fallback.py +2 -2
  17. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/function.py +1 -1
  18. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/gemini.py +26 -22
  19. pydantic_ai_slim-0.2.6/pydantic_ai/models/google.py +569 -0
  20. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/groq.py +12 -6
  21. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/instrumented.py +43 -33
  22. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/mistral.py +15 -9
  23. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/openai.py +46 -8
  24. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/test.py +1 -1
  25. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/wrapper.py +1 -1
  26. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/__init__.py +4 -0
  27. pydantic_ai_slim-0.2.6/pydantic_ai/providers/google.py +143 -0
  28. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/google_vertex.py +3 -3
  29. pydantic_ai_slim-0.2.6/pydantic_ai/providers/openrouter.py +69 -0
  30. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/result.py +13 -21
  31. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/tools.py +34 -2
  32. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/usage.py +1 -1
  33. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pyproject.toml +2 -0
  34. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/README.md +0 -0
  35. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/__init__.py +0 -0
  36. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/__main__.py +0 -0
  37. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_a2a.py +0 -0
  38. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_griffe.py +0 -0
  39. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_pydantic.py +0 -0
  40. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_system_prompt.py +0 -0
  41. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/_utils.py +0 -0
  42. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/common_tools/__init__.py +0 -0
  43. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  44. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/common_tools/tavily.py +0 -0
  45. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/format_as_xml.py +0 -0
  46. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/format_prompt.py +0 -0
  47. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/mcp.py +0 -0
  48. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/models/_json_schema.py +0 -0
  49. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/anthropic.py +0 -0
  50. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/azure.py +0 -0
  51. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/bedrock.py +0 -0
  52. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/cohere.py +0 -0
  53. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/deepseek.py +0 -0
  54. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/google_gla.py +0 -0
  55. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/groq.py +0 -0
  56. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/mistral.py +0 -0
  57. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/providers/openai.py +0 -0
  58. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/py.typed +0 -0
  59. {pydantic_ai_slim-0.2.4 → pydantic_ai_slim-0.2.6}/pydantic_ai/settings.py +0 -0
@@ -17,3 +17,5 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
17
17
  /docs-site/.wrangler/
18
18
  /CLAUDE.md
19
19
  node_modules/
20
+ **.idea/
21
+ .coverage*
@@ -0,0 +1,21 @@
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) Pydantic Services Inc. 2024 to present
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -1,9 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.2.4
3
+ Version: 0.2.6
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
6
6
  License-Expression: MIT
7
+ License-File: LICENSE
7
8
  Classifier: Development Status :: 4 - Beta
8
9
  Classifier: Environment :: Console
9
10
  Classifier: Environment :: MacOS X
@@ -29,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
30
  Requires-Dist: griffe>=1.3.2
30
31
  Requires-Dist: httpx>=0.27
31
32
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.2.4
33
+ Requires-Dist: pydantic-graph==0.2.6
33
34
  Requires-Dist: pydantic>=2.10
34
35
  Requires-Dist: typing-inspection>=0.4.0
35
36
  Provides-Extra: a2a
36
- Requires-Dist: fasta2a==0.2.4; extra == 'a2a'
37
+ Requires-Dist: fasta2a==0.2.6; extra == 'a2a'
37
38
  Provides-Extra: anthropic
38
39
  Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
39
40
  Provides-Extra: bedrock
@@ -47,7 +48,9 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
47
48
  Provides-Extra: duckduckgo
48
49
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
49
50
  Provides-Extra: evals
50
- Requires-Dist: pydantic-evals==0.2.4; extra == 'evals'
51
+ Requires-Dist: pydantic-evals==0.2.6; extra == 'evals'
52
+ Provides-Extra: google
53
+ Requires-Dist: google-genai>=1.15.0; extra == 'google'
51
54
  Provides-Extra: groq
52
55
  Requires-Dist: groq>=0.15.0; extra == 'groq'
53
56
  Provides-Extra: logfire
@@ -26,7 +26,7 @@ from . import (
26
26
  )
27
27
  from .result import OutputDataT, ToolOutput
28
28
  from .settings import ModelSettings, merge_model_settings
29
- from .tools import RunContext, Tool, ToolDefinition
29
+ from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from .mcp import MCPServer
@@ -97,6 +97,8 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
97
97
 
98
98
  tracer: Tracer
99
99
 
100
+ prepare_tools: ToolsPrepareFunc[DepsT] | None = None
101
+
100
102
 
101
103
  class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
102
104
  """The base class for all agent nodes.
@@ -196,7 +198,9 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
196
198
  for i, part in enumerate(msg.parts):
197
199
  if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
198
200
  # Look up the runner by its ref
199
- if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):
201
+ if runner := self.system_prompt_dynamic_functions.get( # pragma: lax no cover
202
+ part.dynamic_ref
203
+ ):
200
204
  updated_part_content = await runner.run(run_context)
201
205
  msg.parts[i] = _messages.SystemPromptPart(
202
206
  updated_part_content, dynamic_ref=part.dynamic_ref
@@ -239,6 +243,11 @@ async def _prepare_request_parameters(
239
243
  *map(add_mcp_server_tools, ctx.deps.mcp_servers),
240
244
  )
241
245
 
246
+ if ctx.deps.prepare_tools:
247
+ # Prepare the tools using the provided function
248
+ # This also acts over tool definitions pulled from MCP servers
249
+ function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
250
+
242
251
  output_schema = ctx.deps.output_schema
243
252
  return models.ModelRequestParameters(
244
253
  function_tools=function_tool_defs,
@@ -265,7 +274,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
265
274
  if self._did_stream:
266
275
  # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
267
276
  # means that the stream was started but not finished before `run()` was called
268
- raise exceptions.AgentRunError('You must finish streaming before calling run()')
277
+ raise exceptions.AgentRunError('You must finish streaming before calling run()') # pragma: no cover
269
278
 
270
279
  return await self._make_request(ctx)
271
280
 
@@ -316,7 +325,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
316
325
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
317
326
  ) -> CallToolsNode[DepsT, NodeRunEndT]:
318
327
  if self._result is not None:
319
- return self._result
328
+ return self._result # pragma: no cover
320
329
 
321
330
  model_settings, model_request_parameters = await self._prepare_request(ctx)
322
331
  model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
@@ -333,7 +342,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
333
342
  ctx.state.message_history.append(self.request)
334
343
 
335
344
  # Check usage
336
- if ctx.deps.usage_limits:
345
+ if ctx.deps.usage_limits: # pragma: no branch
337
346
  ctx.deps.usage_limits.check_before_request(ctx.state.usage)
338
347
 
339
348
  # Increment run_step
@@ -350,7 +359,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
350
359
  ) -> CallToolsNode[DepsT, NodeRunEndT]:
351
360
  # Update usage
352
361
  ctx.state.usage.incr(response.usage)
353
- if ctx.deps.usage_limits:
362
+ if ctx.deps.usage_limits: # pragma: no branch
354
363
  ctx.deps.usage_limits.check_tokens(ctx.state.usage)
355
364
 
356
365
  # Append the model response to state.message_history
@@ -735,7 +744,7 @@ async def _tool_from_mcp_server(
735
744
 
736
745
  for server in ctx.deps.mcp_servers:
737
746
  tools = await server.list_tools()
738
- if tool_name in {tool.name for tool in tools}:
747
+ if tool_name in {tool.name for tool in tools}: # pragma: no branch
739
748
  return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries)
740
749
  return None
741
750
 
@@ -53,11 +53,11 @@ PYDANTIC_AI_HOME = Path.home() / '.pydantic-ai'
53
53
  This folder is used to store the prompt history and configuration.
54
54
  """
55
55
 
56
- PROMPT_HISTORY_PATH = PYDANTIC_AI_HOME / 'prompt-history.txt'
56
+ PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
57
57
 
58
58
 
59
59
  class SimpleCodeBlock(CodeBlock):
60
- """Customised code blocks in markdown.
60
+ """Customized code blocks in markdown.
61
61
 
62
62
  This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
63
63
  """
@@ -70,7 +70,7 @@ class SimpleCodeBlock(CodeBlock):
70
70
 
71
71
 
72
72
  class LeftHeading(Heading):
73
- """Customised headings in markdown to stop centering and prepend markdown style hashes."""
73
+ """Customized headings in markdown to stop centering and prepend markdown style hashes."""
74
74
 
75
75
  def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
76
76
  # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
@@ -202,7 +202,7 @@ Special prompts:
202
202
  elif args.code_theme == 'dark':
203
203
  code_theme = 'monokai'
204
204
  else:
205
- code_theme = args.code_theme
205
+ code_theme = args.code_theme # pragma: no cover
206
206
 
207
207
  if prompt := cast(str, args.prompt):
208
208
  try:
@@ -211,27 +211,26 @@ Special prompts:
211
211
  pass
212
212
  return 0
213
213
 
214
- # Ensure the history directory and file exist
215
- PROMPT_HISTORY_PATH.parent.mkdir(parents=True, exist_ok=True)
216
- PROMPT_HISTORY_PATH.touch(exist_ok=True)
217
-
218
- # doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
219
- session: PromptSession[Any] = PromptSession(history=FileHistory(str(PROMPT_HISTORY_PATH)))
220
214
  try:
221
- return asyncio.run(run_chat(session, stream, agent, console, code_theme, prog_name))
215
+ return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
222
216
  except KeyboardInterrupt: # pragma: no cover
223
217
  return 0
224
218
 
225
219
 
226
220
  async def run_chat(
227
- session: PromptSession[Any],
228
221
  stream: bool,
229
222
  agent: Agent[AgentDepsT, OutputDataT],
230
223
  console: Console,
231
224
  code_theme: str,
232
225
  prog_name: str,
226
+ config_dir: Path | None = None,
233
227
  deps: AgentDepsT = None,
234
228
  ) -> int:
229
+ prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
230
+ prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
231
+ prompt_history_path.touch(exist_ok=True)
232
+ session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))
233
+
235
234
  multiline = False
236
235
  messages: list[ModelMessage] = []
237
236
 
@@ -140,8 +140,8 @@ class OutputSchema(Generic[OutputDataT]):
140
140
  self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
141
141
  ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None:
142
142
  """Find a tool that matches one of the calls, with a specific name."""
143
- for part in parts:
144
- if isinstance(part, _messages.ToolCallPart):
143
+ for part in parts: # pragma: no branch
144
+ if isinstance(part, _messages.ToolCallPart): # pragma: no branch
145
145
  if part.tool_name == tool_name:
146
146
  return part, self.tools[tool_name]
147
147
 
@@ -151,7 +151,7 @@ class OutputSchema(Generic[OutputDataT]):
151
151
  ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]:
152
152
  """Find a tool that matches one of the calls."""
153
153
  for part in parts:
154
- if isinstance(part, _messages.ToolCallPart):
154
+ if isinstance(part, _messages.ToolCallPart): # pragma: no branch
155
155
  if result := self.tools.get(part.tool_name):
156
156
  yield part, result
157
157
 
@@ -201,7 +201,7 @@ class OutputSchemaTool(Generic[OutputDataT]):
201
201
  if description is None:
202
202
  tool_description = json_schema_description
203
203
  else:
204
- tool_description = f'{description}. {json_schema_description}'
204
+ tool_description = f'{description}. {json_schema_description}' # pragma: no cover
205
205
  else:
206
206
  tool_description = description or DEFAULT_DESCRIPTION
207
207
  if multiple:
@@ -243,7 +243,7 @@ class OutputSchemaTool(Generic[OutputDataT]):
243
243
  )
244
244
  raise ToolRetryError(m) from e
245
245
  else:
246
- raise
246
+ raise # pragma: lax no cover
247
247
  else:
248
248
  if k := self.tool_def.outer_typed_dict_key:
249
249
  output = output[k]
@@ -269,11 +269,11 @@ def extract_str_from_union(output_type: Any) -> _utils.Option[Any]:
269
269
  includes_str = True
270
270
  else:
271
271
  remain_args.append(arg)
272
- if includes_str:
272
+ if includes_str: # pragma: no branch
273
273
  if len(remain_args) == 1:
274
274
  return _utils.Some(remain_args[0])
275
275
  else:
276
- return _utils.Some(Union[tuple(remain_args)])
276
+ return _utils.Some(Union[tuple(remain_args)]) # pragma: no cover
277
277
 
278
278
 
279
279
  def get_union_args(tp: Any) -> tuple[Any, ...]:
@@ -164,7 +164,7 @@ class ModelResponsePartsManager:
164
164
  if tool_name is None and self._parts:
165
165
  part_index = len(self._parts) - 1
166
166
  latest_part = self._parts[part_index]
167
- if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)):
167
+ if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)): # pragma: no branch
168
168
  existing_matching_part_and_index = latest_part, part_index
169
169
  else:
170
170
  # vendor_part_id is provided, so look up the corresponding part or delta
@@ -42,6 +42,7 @@ from .tools import (
42
42
  ToolFuncPlain,
43
43
  ToolParams,
44
44
  ToolPrepareFunc,
45
+ ToolsPrepareFunc,
45
46
  )
46
47
 
47
48
  # Re-exporting like this improves auto-import behavior in PyCharm
@@ -148,6 +149,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
148
149
  _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
149
150
  repr=False
150
151
  )
152
+ _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
151
153
  _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
152
154
  _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
153
155
  _default_retries: int = dataclasses.field(repr=False)
@@ -172,6 +174,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
172
174
  retries: int = 1,
173
175
  output_retries: int | None = None,
174
176
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
177
+ prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
175
178
  mcp_servers: Sequence[MCPServer] = (),
176
179
  defer_model_check: bool = False,
177
180
  end_strategy: EndStrategy = 'early',
@@ -200,6 +203,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
200
203
  result_tool_description: str | None = None,
201
204
  result_retries: int | None = None,
202
205
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
206
+ prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
203
207
  mcp_servers: Sequence[MCPServer] = (),
204
208
  defer_model_check: bool = False,
205
209
  end_strategy: EndStrategy = 'early',
@@ -223,6 +227,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
223
227
  retries: int = 1,
224
228
  output_retries: int | None = None,
225
229
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
230
+ prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
226
231
  mcp_servers: Sequence[MCPServer] = (),
227
232
  defer_model_check: bool = False,
228
233
  end_strategy: EndStrategy = 'early',
@@ -251,6 +256,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
251
256
  output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
252
257
  tools: Tools to register with the agent, you can also register tools via the decorators
253
258
  [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
259
+ prepare_tools: custom method to prepare the tool definition of all tools for each step.
260
+ This is useful if you want to customize the definition of multiple tools or you want to register
261
+ a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
254
262
  mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
255
263
  for each server you want the agent to connect to.
256
264
  defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
@@ -334,6 +342,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
334
342
  self._default_retries = retries
335
343
  self._max_result_retries = output_retries if output_retries is not None else retries
336
344
  self._mcp_servers = mcp_servers
345
+ self._prepare_tools = prepare_tools
337
346
  for tool in tools:
338
347
  if isinstance(tool, Tool):
339
348
  self._register_tool(tool)
@@ -585,6 +594,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
585
594
  model_name='gpt-4o',
586
595
  timestamp=datetime.datetime(...),
587
596
  kind='response',
597
+ vendor_id=None,
588
598
  )
589
599
  ),
590
600
  End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -654,8 +664,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
654
664
  usage_limits = usage_limits or _usage.UsageLimits()
655
665
 
656
666
  if isinstance(model_used, InstrumentedModel):
667
+ instrumentation_settings = model_used.settings
657
668
  tracer = model_used.settings.tracer
658
669
  else:
670
+ instrumentation_settings = None
659
671
  tracer = NoOpTracer()
660
672
  agent_name = self.name or 'agent'
661
673
  run_span = tracer.start_span(
@@ -691,6 +703,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
691
703
  mcp_servers=self._mcp_servers,
692
704
  default_retries=self._default_retries,
693
705
  tracer=tracer,
706
+ prepare_tools=self._prepare_tools,
694
707
  get_instructions=get_instructions,
695
708
  )
696
709
  start_node = _agent_graph.UserPromptNode[AgentDepsT](
@@ -723,19 +736,18 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
723
736
  )
724
737
  finally:
725
738
  try:
726
- if run_span.is_recording():
727
- run_span.set_attributes(self._run_span_end_attributes(state, usage))
739
+ if instrumentation_settings and run_span.is_recording():
740
+ run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
728
741
  finally:
729
742
  run_span.end()
730
743
 
731
- def _run_span_end_attributes(self, state: _agent_graph.GraphAgentState, usage: _usage.Usage):
744
+ def _run_span_end_attributes(
745
+ self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
746
+ ):
732
747
  return {
733
748
  **usage.opentelemetry_attributes(),
734
749
  'all_messages_events': json.dumps(
735
- [
736
- InstrumentedModel.event_to_dict(e)
737
- for e in InstrumentedModel.messages_to_otel_events(state.message_history)
738
- ]
750
+ [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)]
739
751
  ),
740
752
  'logfire.json_schema': json.dumps(
741
753
  {
@@ -1001,7 +1013,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1001
1013
  final_result_details = await stream_to_final(streamed_response)
1002
1014
  if final_result_details is not None:
1003
1015
  if yielded:
1004
- raise exceptions.AgentRunError('Agent run produced final results')
1016
+ raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
1005
1017
  yielded = True
1006
1018
 
1007
1019
  messages = graph_ctx.state.message_history.copy()
@@ -1048,11 +1060,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1048
1060
  break
1049
1061
  next_node = await agent_run.next(node)
1050
1062
  if not isinstance(next_node, _agent_graph.AgentNode):
1051
- raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
1063
+ raise exceptions.AgentRunError( # pragma: no cover
1064
+ 'Should have produced a StreamedRunResult before getting here'
1065
+ )
1052
1066
  node = cast(_agent_graph.AgentNode[Any, Any], next_node)
1053
1067
 
1054
1068
  if not yielded:
1055
- raise exceptions.AgentRunError('Agent run finished without producing a final result')
1069
+ raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover
1056
1070
 
1057
1071
  @contextmanager
1058
1072
  def override(
@@ -1226,7 +1240,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1226
1240
  ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
1227
1241
  runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
1228
1242
  self._system_prompt_functions.append(runner)
1229
- if dynamic:
1243
+ if dynamic: # pragma: lax no cover
1230
1244
  self._system_prompt_dynamic_functions[func_.__qualname__] = runner
1231
1245
  return func_
1232
1246
 
@@ -1608,7 +1622,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1608
1622
  if item is self:
1609
1623
  self.name = name
1610
1624
  return
1611
- if parent_frame.f_locals != parent_frame.f_globals:
1625
+ if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch
1612
1626
  # if we couldn't find the agent in locals and globals are a different dict, try globals
1613
1627
  for name, item in parent_frame.f_globals.items():
1614
1628
  if item is self:
@@ -1759,18 +1773,14 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1759
1773
  await agent.to_cli()
1760
1774
  ```
1761
1775
  """
1762
- from prompt_toolkit import PromptSession
1763
- from prompt_toolkit.history import FileHistory
1764
1776
  from rich.console import Console
1765
1777
 
1766
- from pydantic_ai._cli import PROMPT_HISTORY_PATH, run_chat
1778
+ from pydantic_ai._cli import run_chat
1767
1779
 
1768
1780
  # TODO(Marcelo): We need to refactor the CLI code to be able to be able to just pass `agent`, `deps` and
1769
1781
  # `prog_name` from here.
1770
1782
 
1771
- session: PromptSession[Any] = PromptSession(history=FileHistory(str(PROMPT_HISTORY_PATH)))
1772
1783
  await run_chat(
1773
- session=session,
1774
1784
  stream=True,
1775
1785
  agent=self,
1776
1786
  deps=deps,
@@ -1851,6 +1861,7 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
1851
1861
  model_name='gpt-4o',
1852
1862
  timestamp=datetime.datetime(...),
1853
1863
  kind='response',
1864
+ vendor_id=None,
1854
1865
  )
1855
1866
  ),
1856
1867
  End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -1996,6 +2007,7 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
1996
2007
  model_name='gpt-4o',
1997
2008
  timestamp=datetime.datetime(...),
1998
2009
  kind='response',
2010
+ vendor_id=None,
1999
2011
  )
2000
2012
  ),
2001
2013
  End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -2024,7 +2036,7 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
2024
2036
  """Get usage statistics for the run so far, including token usage, model requests, and so on."""
2025
2037
  return self._graph_run.state.usage
2026
2038
 
2027
- def __repr__(self) -> str:
2039
+ def __repr__(self) -> str: # pragma: no cover
2028
2040
  result = self._graph_run.result
2029
2041
  result_repr = '<run not finished>' if result is None else repr(result.output)
2030
2042
  return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>'
@@ -52,6 +52,7 @@ async def model_request(
52
52
  model_name='claude-3-5-haiku-latest',
53
53
  timestamp=datetime.datetime(...),
54
54
  kind='response',
55
+ vendor_id=None,
55
56
  )
56
57
  '''
57
58
  ```
@@ -108,6 +109,7 @@ def model_request_sync(
108
109
  model_name='claude-3-5-haiku-latest',
109
110
  timestamp=datetime.datetime(...),
110
111
  kind='response',
112
+ vendor_id=None,
111
113
  )
112
114
  '''
113
115
  ```
@@ -4,9 +4,9 @@ import json
4
4
  import sys
5
5
 
6
6
  if sys.version_info < (3, 11):
7
- from exceptiongroup import ExceptionGroup
7
+ from exceptiongroup import ExceptionGroup # pragma: lax no cover
8
8
  else:
9
- ExceptionGroup = ExceptionGroup
9
+ ExceptionGroup = ExceptionGroup # pragma: lax no cover
10
10
 
11
11
  __all__ = (
12
12
  'ModelRetry',
@@ -6,7 +6,7 @@ from collections.abc import Sequence
6
6
  from dataclasses import dataclass, field, replace
7
7
  from datetime import datetime
8
8
  from mimetypes import guess_type
9
- from typing import Annotated, Any, Literal, Union, cast, overload
9
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, cast, overload
10
10
 
11
11
  import pydantic
12
12
  import pydantic_core
@@ -17,6 +17,10 @@ from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as
17
17
  from .exceptions import UnexpectedModelBehavior
18
18
  from .usage import Usage
19
19
 
20
+ if TYPE_CHECKING:
21
+ from .models.instrumented import InstrumentationSettings
22
+
23
+
20
24
  AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
21
25
  ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
22
26
  DocumentMediaType: TypeAlias = Literal[
@@ -68,7 +72,7 @@ class SystemPromptPart:
68
72
  part_kind: Literal['system-prompt'] = 'system-prompt'
69
73
  """Part type identifier, this is available on all parts as a discriminator."""
70
74
 
71
- def otel_event(self) -> Event:
75
+ def otel_event(self, _settings: InstrumentationSettings) -> Event:
72
76
  return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
73
77
 
74
78
 
@@ -305,7 +309,7 @@ class UserPromptPart:
305
309
  part_kind: Literal['user-prompt'] = 'user-prompt'
306
310
  """Part type identifier, this is available on all parts as a discriminator."""
307
311
 
308
- def otel_event(self) -> Event:
312
+ def otel_event(self, settings: InstrumentationSettings) -> Event:
309
313
  content: str | list[dict[str, Any] | str]
310
314
  if isinstance(self.content, str):
311
315
  content = self.content
@@ -317,10 +321,12 @@ class UserPromptPart:
317
321
  elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)):
318
322
  content.append({'kind': part.kind, 'url': part.url})
319
323
  elif isinstance(part, BinaryContent):
320
- base64_data = base64.b64encode(part.data).decode()
321
- content.append({'kind': part.kind, 'content': base64_data, 'media_type': part.media_type})
324
+ converted_part = {'kind': part.kind, 'media_type': part.media_type}
325
+ if settings.include_binary_content:
326
+ converted_part['binary_content'] = base64.b64encode(part.data).decode()
327
+ content.append(converted_part)
322
328
  else:
323
- content.append({'kind': part.kind})
329
+ content.append({'kind': part.kind}) # pragma: no cover
324
330
  return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
325
331
 
326
332
 
@@ -357,11 +363,11 @@ class ToolReturnPart:
357
363
  """Return a dictionary representation of the content, wrapping non-dict types appropriately."""
358
364
  # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
359
365
  if isinstance(self.content, dict):
360
- return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
366
+ return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType] # pragma: no cover
361
367
  else:
362
368
  return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
363
369
 
364
- def otel_event(self) -> Event:
370
+ def otel_event(self, _settings: InstrumentationSettings) -> Event:
365
371
  return Event(
366
372
  'gen_ai.tool.message',
367
373
  body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name},
@@ -418,7 +424,7 @@ class RetryPromptPart:
418
424
  description = f'{len(self.content)} validation errors: {json_errors.decode()}'
419
425
  return f'{description}\n\nFix the errors and try again.'
420
426
 
421
- def otel_event(self) -> Event:
427
+ def otel_event(self, _settings: InstrumentationSettings) -> Event:
422
428
  if self.tool_name is None:
423
429
  return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
424
430
  else:
@@ -556,6 +562,16 @@ class ModelResponse:
556
562
  kind: Literal['response'] = 'response'
557
563
  """Message type identifier, this is available on all parts as a discriminator."""
558
564
 
565
+ vendor_details: dict[str, Any] | None = field(default=None, repr=False)
566
+ """Additional vendor-specific details in a serializable format.
567
+
568
+ This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
569
+ For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
570
+ """
571
+
572
+ vendor_id: str | None = None
573
+ """Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
574
+
559
575
  def otel_events(self) -> list[Event]:
560
576
  """Return OpenTelemetry events for the response."""
561
577
  result: list[Event] = []
@@ -619,7 +635,7 @@ class TextPartDelta:
619
635
  ValueError: If `part` is not a `TextPart`.
620
636
  """
621
637
  if not isinstance(part, TextPart):
622
- raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
638
+ raise ValueError('Cannot apply TextPartDeltas to non-TextParts') # pragma: no cover
623
639
  return replace(part, content=part.content + self.content_delta)
624
640
 
625
641
 
@@ -682,7 +698,9 @@ class ToolCallPartDelta:
682
698
  if isinstance(part, ToolCallPartDelta):
683
699
  return self._apply_to_delta(part)
684
700
 
685
- raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
701
+ raise ValueError( # pragma: no cover
702
+ f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}'
703
+ )
686
704
 
687
705
  def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
688
706
  """Internal helper to apply this delta to another delta."""