pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (55)
  1. pydantic_ai/_agent_graph.py +219 -315
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +296 -226
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -155
  10. pydantic_ai/common_tools/duckduckgo.py +5 -2
  11. pydantic_ai/exceptions.py +14 -2
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/mcp.py +147 -84
  15. pydantic_ai/messages.py +19 -9
  16. pydantic_ai/models/__init__.py +43 -19
  17. pydantic_ai/models/anthropic.py +2 -2
  18. pydantic_ai/models/bedrock.py +1 -1
  19. pydantic_ai/models/cohere.py +1 -1
  20. pydantic_ai/models/function.py +50 -24
  21. pydantic_ai/models/gemini.py +3 -11
  22. pydantic_ai/models/google.py +3 -12
  23. pydantic_ai/models/groq.py +2 -1
  24. pydantic_ai/models/huggingface.py +463 -0
  25. pydantic_ai/models/instrumented.py +1 -1
  26. pydantic_ai/models/mistral.py +3 -3
  27. pydantic_ai/models/openai.py +5 -5
  28. pydantic_ai/output.py +21 -7
  29. pydantic_ai/profiles/google.py +1 -1
  30. pydantic_ai/profiles/moonshotai.py +8 -0
  31. pydantic_ai/providers/__init__.py +4 -0
  32. pydantic_ai/providers/google.py +2 -2
  33. pydantic_ai/providers/google_vertex.py +10 -5
  34. pydantic_ai/providers/grok.py +13 -1
  35. pydantic_ai/providers/groq.py +2 -0
  36. pydantic_ai/providers/huggingface.py +88 -0
  37. pydantic_ai/result.py +57 -33
  38. pydantic_ai/tools.py +26 -119
  39. pydantic_ai/toolsets/__init__.py +22 -0
  40. pydantic_ai/toolsets/abstract.py +155 -0
  41. pydantic_ai/toolsets/combined.py +88 -0
  42. pydantic_ai/toolsets/deferred.py +38 -0
  43. pydantic_ai/toolsets/filtered.py +24 -0
  44. pydantic_ai/toolsets/function.py +238 -0
  45. pydantic_ai/toolsets/prefixed.py +37 -0
  46. pydantic_ai/toolsets/prepared.py +36 -0
  47. pydantic_ai/toolsets/renamed.py +42 -0
  48. pydantic_ai/toolsets/wrapper.py +37 -0
  49. pydantic_ai/usage.py +14 -8
  50. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
  51. pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
  52. pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
  53. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  54. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  55. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/providers/grok.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import os
-from typing import overload
+from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from openai import AsyncOpenAI
@@ -21,6 +21,18 @@ except ImportError as _import_error:  # pragma: no cover
         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
+# https://docs.x.ai/docs/models
+GrokModelName = Literal[
+    'grok-4',
+    'grok-4-0709',
+    'grok-3',
+    'grok-3-mini',
+    'grok-3-fast',
+    'grok-3-mini-fast',
+    'grok-2-vision-1212',
+    'grok-2-image-1212',
+]
+
 
 class GrokProvider(Provider[AsyncOpenAI]):
     """Provider for Grok API."""
pydantic_ai/providers/groq.py CHANGED
@@ -12,6 +12,7 @@ from pydantic_ai.profiles.deepseek import deepseek_model_profile
 from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
@@ -47,6 +48,7 @@ class GroqProvider(Provider[AsyncGroq]):
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
+            'moonshotai/': moonshotai_model_profile,
         }
 
         for prefix, profile_func in prefix_to_profile.items():
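
The mapping above is consulted by prefix, so the `'moonshotai/'` entry (trailing slash intact) only matches namespaced Moonshot model names. A hedged, self-contained reduction of that lookup; the matching rule shown (`startswith`) is an assumption, since the loop body is outside this hunk:

```python
from collections.abc import Callable

def moonshotai_model_profile(model_name: str) -> str:
    return f'profile for {model_name}'  # stand-in for the real profile function

prefix_to_profile: dict[str, Callable[[str], str]] = {
    'moonshotai/': moonshotai_model_profile,
}

def profile_for(model_name: str) -> str | None:
    # Assumed prefix rule: the first entry whose key prefixes the model name wins.
    for prefix, profile_func in prefix_to_profile.items():
        if model_name.startswith(prefix):
            return profile_func(model_name)
    return None

assert profile_for('moonshotai/kimi-k2-instruct') is not None
```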
pydantic_ai/providers/huggingface.py ADDED
@@ -0,0 +1,88 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient
+
+from pydantic_ai.exceptions import UserError
+
+try:
+    from huggingface_hub import AsyncInferenceClient
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for Hugging Face."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    @overload
+    def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, api_key: str | None = None) -> None: ...
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: AsyncClient | None = None,
+        provider_name: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base url for the Hugging Face requests.
+            api_key: The API key to use for authentication; if not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
+            provider_name: Name of the provider to use for inference. Available providers are listed in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which selects the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, then `provider_name` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise UserError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
+                'to use the HuggingFace provider.'
+            )
+
+        if http_client is not None:
+            raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.')
+
+        if base_url is not None and provider_name is not None:
+            raise ValueError('Cannot provide both `base_url` and `provider_name`.')
+
+        if hf_client is None:
+            self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url)  # type: ignore
+        else:
+            self._client = hf_client
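
A hedged usage sketch for the new provider: `HuggingFaceModel` is assumed to come from the `pydantic_ai/models/huggingface.py` file added in this same release, and the model ID is an arbitrary example:

```python
import os

from pydantic_ai import Agent
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider

# HF_TOKEN is also read implicitly; passed explicitly here for clarity.
provider = HuggingFaceProvider(api_key=os.environ['HF_TOKEN'])
model = HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct', provider=provider)  # model ID is illustrative
agent = Agent(model)
```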
pydantic_ai/result.py CHANGED
@@ -5,11 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic
+from typing import Generic, cast
 
 from pydantic import ValidationError
 from typing_extensions import TypeVar, deprecated, overload
 
+from pydantic_ai._tool_manager import ToolManager
+
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
     OutputDataT_inv,
@@ -47,6 +49,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _run_ctx: RunContext[AgentDepsT]
     _usage_limits: UsageLimits | None
+    _tool_manager: ToolManager[AgentDepsT]
 
     _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
     _final_result_event: FinalResultEvent | None = field(default=None, init=False)
@@ -95,33 +98,40 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
-            match = self._output_schema.find_named_tool(message.parts, output_tool_name)
-            if match is None:
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                    f'Invalid response, unable to find tool call for {output_tool_name!r}'
                 )
-
-            call, output_tool = match
-            result_data = await output_tool.process(
-                call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
-            )
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(  # pragma: no cover
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)
-        return result_data
-
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
 
@@ -139,13 +149,19 @@
         """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
         if isinstance(e, _messages.PartStartEvent):
             new_part = e.part
-            if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema):
-                for call, _ in output_schema.find_tool([new_part]):  # pragma: no branch
-                    return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id)
-            elif isinstance(new_part, _messages.TextPart) and isinstance(
+            if isinstance(new_part, _messages.TextPart) and isinstance(
                 output_schema, TextOutputSchema
             ):  # pragma: no branch
                 return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
+            elif isinstance(new_part, _messages.ToolCallPart) and (
+                tool_def := self._tool_manager.get_tool_def(new_part.tool_name)
+            ):
+                if tool_def.kind == 'output':
+                    return _messages.FinalResultEvent(
+                        tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id
+                    )
+                elif tool_def.kind == 'deferred':
+                    return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
 
         usage_checking_stream = _get_usage_checking_stream_response(
             self._raw_stream_response, self._usage_limits, self.usage
@@ -180,6 +196,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _output_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
+    _tool_manager: ToolManager[AgentDepsT]
 
     _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
@@ -320,7 +337,7 @@
                     yield await self.validate_structured_output(structured_message, allow_partial=not is_last)
                 except ValidationError:
                     if is_last:
-                        raise  # pragma: lax no cover
+                        raise  # pragma: no cover
 
     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -413,36 +430,43 @@
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
    ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None:
-            match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
-            if match is None:
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                    f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
                 )
-
-            call, output_tool = match
-            result_data = await output_tool.process(
-                call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
-            )
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
-        return result_data
-
     async def _validate_text_output(self, text: str) -> str:
         for validator in self._output_validators:
-            text = await validator.validate(text, None, self._run_ctx)  # pragma: no cover
+            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
         return text
 
     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
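
Both `_validate_response` rewrites above delegate output-tool calls to the new `ToolManager` and surface deferred calls as a `DeferredToolCalls` value, raising `UserError` unless that type is registered as an allowed output. A hedged sketch of the pattern the error message asks for; the tool definition and model string are illustrative, and the `DeferredToolset` and list-form `output_type` signatures are assumptions based on this diff:

```python
from pydantic_ai import Agent
from pydantic_ai.output import DeferredToolCalls
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets import DeferredToolset

# A tool whose result is produced outside this agent run; the toolset is
# assumed to mark its definitions as kind='deferred'.
toolset = DeferredToolset([ToolDefinition(name='send_email')])
agent = Agent('openai:gpt-4o', toolsets=[toolset], output_type=[str, DeferredToolCalls])
```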
pydantic_ai/tools.py CHANGED
@@ -1,20 +1,15 @@
 from __future__ import annotations as _annotations
 
-import dataclasses
-import json
 from collections.abc import Awaitable, Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Literal, Union
 
-from opentelemetry.trace import Tracer
-from pydantic import ValidationError
 from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
 from pydantic_core import SchemaValidator, core_schema
 from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar
 
-from . import _function_schema, _utils, messages as _messages
+from . import _function_schema, _utils
 from ._run_context import AgentDepsT, RunContext
-from .exceptions import ModelRetry, UnexpectedModelBehavior
 
 __all__ = (
     'AgentDepsT',
@@ -32,7 +27,6 @@ __all__ = (
     'ToolDefinition',
 )
 
-from .messages import ToolReturnPart
 
 ToolParams = ParamSpec('ToolParams', default=...)
 """Retrieval function param spec."""
@@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]):
     This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
     """
 
-    # TODO: Consider moving this current_retry state to live on something other than the tool.
-    # We've worked around this for now by copying instances of the tool when creating new runs,
-    # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things
-    # up, though is also likely a larger effort to refactor.
-    current_retry: int = field(default=0, init=False)
-
     def __init__(
         self,
         function: ToolFuncEither[AgentDepsT],
@@ -303,6 +291,15 @@ class Tool(Generic[AgentDepsT]):
             function_schema=function_schema,
         )
 
+    @property
+    def tool_def(self):
+        return ToolDefinition(
+            name=self.name,
+            description=self.description,
+            parameters_json_schema=self.function_schema.json_schema,
+            strict=self.strict,
+        )
+
     async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
         """Get the tool definition.
 
@@ -312,113 +309,11 @@
         Returns:
             return a `ToolDefinition` or `None` if the tools should not be registered for this run.
         """
-        tool_def = ToolDefinition(
-            name=self.name,
-            description=self.description,
-            parameters_json_schema=self.function_schema.json_schema,
-            strict=self.strict,
-        )
+        base_tool_def = self.tool_def
         if self.prepare is not None:
-            return await self.prepare(ctx, tool_def)
+            return await self.prepare(ctx, base_tool_def)
         else:
-            return tool_def
-
-    async def run(
-        self,
-        message: _messages.ToolCallPart,
-        run_context: RunContext[AgentDepsT],
-        tracer: Tracer,
-        include_content: bool = False,
-    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
-        """Run the tool function asynchronously.
-
-        This method wraps `_run` in an OpenTelemetry span.
-
-        See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
-        """
-        span_attributes = {
-            'gen_ai.tool.name': self.name,
-            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
-            'gen_ai.tool.call.id': message.tool_call_id,
-            **({'tool_arguments': message.args_as_json_str()} if include_content else {}),
-            'logfire.msg': f'running tool: {self.name}',
-            # add the JSON schema so these attributes are formatted nicely in Logfire
-            'logfire.json_schema': json.dumps(
-                {
-                    'type': 'object',
-                    'properties': {
-                        **(
-                            {
-                                'tool_arguments': {'type': 'object'},
-                                'tool_response': {'type': 'object'},
-                            }
-                            if include_content
-                            else {}
-                        ),
-                        'gen_ai.tool.name': {},
-                        'gen_ai.tool.call.id': {},
-                    },
-                }
-            ),
-        }
-        with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
-            response = await self._run(message, run_context)
-            if include_content and span.is_recording():
-                span.set_attribute(
-                    'tool_response',
-                    response.model_response_str()
-                    if isinstance(response, ToolReturnPart)
-                    else response.model_response(),
-                )
-
-            return response
-
-    async def _run(
-        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
-    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
-        try:
-            validator = self.function_schema.validator
-            if isinstance(message.args, str):
-                args_dict = validator.validate_json(message.args or '{}')
-            else:
-                args_dict = validator.validate_python(message.args or {})
-        except ValidationError as e:
-            return self._on_error(e, message)
-
-        ctx = dataclasses.replace(
-            run_context,
-            retry=self.current_retry,
-            tool_name=message.tool_name,
-            tool_call_id=message.tool_call_id,
-        )
-        try:
-            response_content = await self.function_schema.call(args_dict, ctx)
-        except ModelRetry as e:
-            return self._on_error(e, message)
-
-        self.current_retry = 0
-        return _messages.ToolReturnPart(
-            tool_name=message.tool_name,
-            content=response_content,
-            tool_call_id=message.tool_call_id,
-        )
-
-    def _on_error(
-        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
-    ) -> _messages.RetryPromptPart:
-        self.current_retry += 1
-        if self.max_retries is None or self.current_retry > self.max_retries:
-            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
-        else:
-            if isinstance(exc, ValidationError):
-                content = exc.errors(include_url=False, include_context=False)
-            else:
-                content = exc.message
-            return _messages.RetryPromptPart(
-                tool_name=call_message.tool_name,
-                content=content,
-                tool_call_id=call_message.tool_call_id,
-            )
+            return base_tool_def
 
 
 ObjectJsonSchema: TypeAlias = dict[str, Any]
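
The hunks above move the inline `ToolDefinition` construction into the new `tool_def` property, while `Tool.run` and its retry bookkeeping are deleted in favor of the new `_tool_manager.py`. A hedged sketch of the property in use; the wrapped function is illustrative:

```python
import random

from pydantic_ai.tools import Tool

def roll_die() -> int:
    """Roll a six-sided die."""
    return random.randint(1, 6)

tool = Tool(roll_die, takes_ctx=False)
print(tool.tool_def.name)  # 'roll_die'; the parameter schema comes from the function signature
```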
@@ -429,6 +324,9 @@ This type is used to define tools parameters (aka arguments) in [ToolDefinition]
 With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
 """
 
+ToolKind: TypeAlias = Literal['function', 'output', 'deferred']
+"""Kind of tool."""
+
 
 @dataclass(repr=False)
 class ToolDefinition:
@@ -440,7 +338,7 @@
     name: str
     """The name of the tool."""
 
-    parameters_json_schema: ObjectJsonSchema
+    parameters_json_schema: ObjectJsonSchema = field(default_factory=lambda: {'type': 'object', 'properties': {}})
     """The JSON schema for the tool's parameters."""
 
     description: str | None = None
@@ -464,4 +362,13 @@
     Note: this is currently only supported by OpenAI models.
     """
 
+    kind: ToolKind = field(default='function')
+    """The kind of tool:
+
+    - `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model
+    - `'output'`: a tool that passes through an output value that ends the run
+    - `'deferred'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running.
+      When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call.
+    """
+
     __repr__ = _utils.dataclasses_no_defaults_repr
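
With the new `parameters_json_schema` default and `kind` field, a no-argument tool definition can be declared tersely; a short sketch grounded in the defaults shown above:

```python
from pydantic_ai.tools import ToolDefinition

approval = ToolDefinition(name='request_human_approval', kind='deferred')
assert approval.parameters_json_schema == {'type': 'object', 'properties': {}}
assert approval.kind == 'deferred'
```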
pydantic_ai/toolsets/__init__.py ADDED
@@ -0,0 +1,22 @@
+from .abstract import AbstractToolset, ToolsetTool
+from .combined import CombinedToolset
+from .deferred import DeferredToolset
+from .filtered import FilteredToolset
+from .function import FunctionToolset
+from .prefixed import PrefixedToolset
+from .prepared import PreparedToolset
+from .renamed import RenamedToolset
+from .wrapper import WrapperToolset
+
+__all__ = (
+    'AbstractToolset',
+    'ToolsetTool',
+    'CombinedToolset',
+    'DeferredToolset',
+    'FilteredToolset',
+    'FunctionToolset',
+    'PrefixedToolset',
+    'RenamedToolset',
+    'PreparedToolset',
+    'WrapperToolset',
+)
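
A hedged composition sketch for the new package: `FunctionToolset.tool` as a registration decorator and `PrefixedToolset`'s argument order are assumptions inferred from the module names, not shown in this diff:

```python
from pydantic_ai.toolsets import FunctionToolset, PrefixedToolset

toolset = FunctionToolset()

@toolset.tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

# Wrap to expose the tool under a namespaced name, e.g. 'math_add'.
math_tools = PrefixedToolset(toolset, prefix='math')
```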