pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/providers/grok.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import os
-from typing import overload
+from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from openai import AsyncOpenAI
@@ -21,6 +21,18 @@ except ImportError as _import_error: # pragma: no cover
         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
+# https://docs.x.ai/docs/models
+GrokModelName = Literal[
+    'grok-4',
+    'grok-4-0709',
+    'grok-3',
+    'grok-3-mini',
+    'grok-3-fast',
+    'grok-3-mini-fast',
+    'grok-2-vision-1212',
+    'grok-2-image-1212',
+]
+
 
 class GrokProvider(Provider[AsyncOpenAI]):
     """Provider for Grok API."""
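
The new `GrokModelName` literal enumerates the xAI model names documented at https://docs.x.ai/docs/models, so model names can be type-checked. A minimal usage sketch (the `OpenAIModel` wiring is an assumption based on how other OpenAI-compatible providers are used, and the API key is a placeholder):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.providers.grok import GrokModelName, GrokProvider

    # Annotating with GrokModelName lets a type checker reject names that are
    # not members of the Literal above.
    model_name: GrokModelName = 'grok-3-mini'

    model = OpenAIModel(model_name, provider=GrokProvider(api_key='your-xai-api-key'))
    agent = Agent(model)
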
pydantic_ai/providers/groq.py
CHANGED
@@ -12,6 +12,7 @@ from pydantic_ai.profiles.deepseek import deepseek_model_profile
 from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
@@ -47,6 +48,7 @@ class GroqProvider(Provider[AsyncGroq]):
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
+            'moonshotai/': moonshotai_model_profile,
         }
 
         for prefix, profile_func in prefix_to_profile.items():
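
The addition registers the new MoonshotAI profile (see `pydantic_ai/profiles/moonshotai.py` in the file list) under the `moonshotai/` prefix in Groq's prefix-to-profile lookup. A standalone sketch of that dispatch pattern follows; the helper, dictionary contents, and model name are illustrative rather than taken from this diff:

    from typing import Any, Callable

    def choose_profile(model_name: str, prefix_to_profile: dict[str, Callable[[str], Any]]) -> Any:
        # The first registered prefix that the model name starts with decides
        # which profile function is called for it.
        for prefix, profile_func in prefix_to_profile.items():
            if model_name.startswith(prefix):
                return profile_func(model_name)
        return None

    # Hypothetical model name, used only to show that the 'moonshotai/' prefix matches.
    profile = choose_profile('moonshotai/example-model', {'moonshotai/': lambda name: f'moonshot profile for {name}'})
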
pydantic_ai/providers/huggingface.py
ADDED
@@ -0,0 +1,88 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient
+
+from pydantic_ai.exceptions import UserError
+
+try:
+    from huggingface_hub import AsyncInferenceClient
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for Hugging Face."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    @overload
+    def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, api_key: str | None = None) -> None: ...
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: AsyncClient | None = None,
+        provider_name: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base URL for Hugging Face requests.
+            api_key: The API key to use for authentication. If not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
+            provider_name: Name of the provider to use for inference. Available providers are listed in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which selects the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, `provider_name` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise UserError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)` '
+                'to use the HuggingFace provider.'
+            )
+
+        if http_client is not None:
+            raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.')
+
+        if base_url is not None and provider_name is not None:
+            raise ValueError('Cannot provide both `base_url` and `provider_name`.')
+
+        if hf_client is None:
+            self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url)  # type: ignore
+        else:
+            self._client = hf_client
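
A short construction sketch grounded in the constructor above; 'together' is just one entry from the partners list linked in the docstring, and the token is a placeholder:

    from pydantic_ai.providers.huggingface import HuggingFaceProvider

    # With no arguments the provider reads HF_TOKEN from the environment and
    # raises UserError if it is missing.
    provider = HuggingFaceProvider()

    # Pinning a specific inference provider; passing both base_url and
    # provider_name together would raise ValueError.
    provider = HuggingFaceProvider(provider_name='together', api_key='hf_xxx')
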
pydantic_ai/result.py
CHANGED
@@ -5,11 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic
+from typing import Generic, cast
 
 from pydantic import ValidationError
 from typing_extensions import TypeVar, deprecated, overload
 
+from pydantic_ai._tool_manager import ToolManager
+
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
     OutputDataT_inv,
@@ -47,6 +49,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _run_ctx: RunContext[AgentDepsT]
     _usage_limits: UsageLimits | None
+    _tool_manager: ToolManager[AgentDepsT]
 
     _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
     _final_result_event: FinalResultEvent | None = field(default=None, init=False)
@@ -95,33 +98,40 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
-
-
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool
+                    f'Invalid response, unable to find tool call for {output_tool_name!r}'
                 )
-
-
-
-
-
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(  # pragma: no cover
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)
-        return result_data
-
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
 
@@ -139,13 +149,19 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
         if isinstance(e, _messages.PartStartEvent):
             new_part = e.part
-            if isinstance(new_part, _messages.
-                for call, _ in output_schema.find_tool([new_part]):  # pragma: no branch
-                    return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id)
-            elif isinstance(new_part, _messages.TextPart) and isinstance(
+            if isinstance(new_part, _messages.TextPart) and isinstance(
                 output_schema, TextOutputSchema
             ):  # pragma: no branch
                 return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
+            elif isinstance(new_part, _messages.ToolCallPart) and (
+                tool_def := self._tool_manager.get_tool_def(new_part.tool_name)
+            ):
+                if tool_def.kind == 'output':
+                    return _messages.FinalResultEvent(
+                        tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id
+                    )
+                elif tool_def.kind == 'deferred':
+                    return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
 
         usage_checking_stream = _get_usage_checking_stream_response(
             self._raw_stream_response, self._usage_limits, self.usage
@@ -180,6 +196,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _output_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
+    _tool_manager: ToolManager[AgentDepsT]
 
     _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
@@ -320,7 +337,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
                     yield await self.validate_structured_output(structured_message, allow_partial=not is_last)
                 except ValidationError:
                     if is_last:
-                        raise  # pragma:
+                        raise  # pragma: no cover
 
     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -413,36 +430,43 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None:
-
-
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool
+                    f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
                 )
-
-
-
-
-
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
            )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
-        return result_data
-
     async def _validate_text_output(self, text: str) -> str:
         for validator in self._output_validators:
-            text = await validator.validate(text,
+            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
         return text
 
     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
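
Both `validate_structured_output` methods now short-circuit when the model calls a deferred tool: the run produces a `DeferredToolCalls` value instead of raising, provided `DeferredToolCalls` is among the agent's output types. A hedged sketch of the setup the error message asks for; the `DeferredToolCalls` and `DeferredToolset` import paths, the `toolsets=` parameter, and the model string are assumptions based on other files in this release, not on code shown in this file:

    from pydantic_ai import Agent
    from pydantic_ai.output import DeferredToolCalls
    from pydantic_ai.tools import ToolDefinition
    from pydantic_ai.toolsets import DeferredToolset

    # A deferred tool is only described to the model; its result is produced
    # outside the agent run (see the 'deferred' ToolKind added in tools.py below).
    toolset = DeferredToolset([ToolDefinition(name='ask_human', description='Ask a human operator for help.')])

    # Including DeferredToolCalls among the output types avoids the UserError
    # raised above when a deferred tool call is present in the model response.
    agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolCalls], toolsets=[toolset])
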
pydantic_ai/tools.py
CHANGED
@@ -1,20 +1,15 @@
 from __future__ import annotations as _annotations
 
-import dataclasses
-import json
 from collections.abc import Awaitable, Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Literal, Union
 
-from opentelemetry.trace import Tracer
-from pydantic import ValidationError
 from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
 from pydantic_core import SchemaValidator, core_schema
 from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar
 
-from . import _function_schema, _utils
+from . import _function_schema, _utils
 from ._run_context import AgentDepsT, RunContext
-from .exceptions import ModelRetry, UnexpectedModelBehavior
 
 __all__ = (
     'AgentDepsT',
@@ -32,7 +27,6 @@ __all__ = (
     'ToolDefinition',
 )
 
-from .messages import ToolReturnPart
 
 ToolParams = ParamSpec('ToolParams', default=...)
 """Retrieval function param spec."""
@@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]):
     This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
     """
 
-    # TODO: Consider moving this current_retry state to live on something other than the tool.
-    # We've worked around this for now by copying instances of the tool when creating new runs,
-    # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things
-    # up, though is also likely a larger effort to refactor.
-    current_retry: int = field(default=0, init=False)
-
     def __init__(
         self,
         function: ToolFuncEither[AgentDepsT],
@@ -303,6 +291,15 @@ class Tool(Generic[AgentDepsT]):
             function_schema=function_schema,
         )
 
+    @property
+    def tool_def(self):
+        return ToolDefinition(
+            name=self.name,
+            description=self.description,
+            parameters_json_schema=self.function_schema.json_schema,
+            strict=self.strict,
+        )
+
 
     async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
         """Get the tool definition.
@@ -312,113 +309,11 @@ class Tool(Generic[AgentDepsT]):
         Returns:
             return a `ToolDefinition` or `None` if the tools should not be registered for this run.
         """
-
-            name=self.name,
-            description=self.description,
-            parameters_json_schema=self.function_schema.json_schema,
-            strict=self.strict,
-        )
+        base_tool_def = self.tool_def
         if self.prepare is not None:
-            return await self.prepare(ctx,
+            return await self.prepare(ctx, base_tool_def)
         else:
-            return
-
-    async def run(
-        self,
-        message: _messages.ToolCallPart,
-        run_context: RunContext[AgentDepsT],
-        tracer: Tracer,
-        include_content: bool = False,
-    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
-        """Run the tool function asynchronously.
-
-        This method wraps `_run` in an OpenTelemetry span.
-
-        See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
-        """
-        span_attributes = {
-            'gen_ai.tool.name': self.name,
-            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
-            'gen_ai.tool.call.id': message.tool_call_id,
-            **({'tool_arguments': message.args_as_json_str()} if include_content else {}),
-            'logfire.msg': f'running tool: {self.name}',
-            # add the JSON schema so these attributes are formatted nicely in Logfire
-            'logfire.json_schema': json.dumps(
-                {
-                    'type': 'object',
-                    'properties': {
-                        **(
-                            {
-                                'tool_arguments': {'type': 'object'},
-                                'tool_response': {'type': 'object'},
-                            }
-                            if include_content
-                            else {}
-                        ),
-                        'gen_ai.tool.name': {},
-                        'gen_ai.tool.call.id': {},
-                    },
-                }
-            ),
-        }
-        with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
-            response = await self._run(message, run_context)
-            if include_content and span.is_recording():
-                span.set_attribute(
-                    'tool_response',
-                    response.model_response_str()
-                    if isinstance(response, ToolReturnPart)
-                    else response.model_response(),
-                )
-
-            return response
-
-    async def _run(
-        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
-    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
-        try:
-            validator = self.function_schema.validator
-            if isinstance(message.args, str):
-                args_dict = validator.validate_json(message.args or '{}')
-            else:
-                args_dict = validator.validate_python(message.args or {})
-        except ValidationError as e:
-            return self._on_error(e, message)
-
-        ctx = dataclasses.replace(
-            run_context,
-            retry=self.current_retry,
-            tool_name=message.tool_name,
-            tool_call_id=message.tool_call_id,
-        )
-        try:
-            response_content = await self.function_schema.call(args_dict, ctx)
-        except ModelRetry as e:
-            return self._on_error(e, message)
-
-        self.current_retry = 0
-        return _messages.ToolReturnPart(
-            tool_name=message.tool_name,
-            content=response_content,
-            tool_call_id=message.tool_call_id,
-        )
-
-    def _on_error(
-        self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
-    ) -> _messages.RetryPromptPart:
-        self.current_retry += 1
-        if self.max_retries is None or self.current_retry > self.max_retries:
-            raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
-        else:
-            if isinstance(exc, ValidationError):
-                content = exc.errors(include_url=False, include_context=False)
-            else:
-                content = exc.message
-            return _messages.RetryPromptPart(
-                tool_name=call_message.tool_name,
-                content=content,
-                tool_call_id=call_message.tool_call_id,
-            )
+            return base_tool_def
 
 
 ObjectJsonSchema: TypeAlias = dict[str, Any]
@@ -429,6 +324,9 @@ This type is used to define tools parameters (aka arguments) in [ToolDefinition]
 With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
 """
 
+ToolKind: TypeAlias = Literal['function', 'output', 'deferred']
+"""Kind of tool."""
+
 
 @dataclass(repr=False)
 class ToolDefinition:
@@ -440,7 +338,7 @@ class ToolDefinition:
     name: str
     """The name of the tool."""
 
-    parameters_json_schema: ObjectJsonSchema
+    parameters_json_schema: ObjectJsonSchema = field(default_factory=lambda: {'type': 'object', 'properties': {}})
     """The JSON schema for the tool's parameters."""
 
     description: str | None = None
@@ -464,4 +362,13 @@ class ToolDefinition:
     Note: this is currently only supported by OpenAI models.
     """
 
+    kind: ToolKind = field(default='function')
+    """The kind of tool:
+
+    - `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model
+    - `'output'`: a tool that passes through an output value that ends the run
+    - `'deferred'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running.
+      When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call.
+    """
+
     __repr__ = _utils.dataclasses_no_defaults_repr
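
With `parameters_json_schema` gaining a default and the new `kind` field, a `ToolDefinition` can now be declared with just a name. A small sketch grounded in the dataclass above; the tool name is illustrative:

    from pydantic_ai.tools import ToolDefinition

    # 'deferred' marks a tool whose result is produced outside the agent run.
    handoff = ToolDefinition(name='ask_human', kind='deferred')

    print(handoff.parameters_json_schema)  # {'type': 'object', 'properties': {}}
    print(handoff.kind)                    # deferred
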
pydantic_ai/toolsets/__init__.py
ADDED
@@ -0,0 +1,22 @@
+from .abstract import AbstractToolset, ToolsetTool
+from .combined import CombinedToolset
+from .deferred import DeferredToolset
+from .filtered import FilteredToolset
+from .function import FunctionToolset
+from .prefixed import PrefixedToolset
+from .prepared import PreparedToolset
+from .renamed import RenamedToolset
+from .wrapper import WrapperToolset
+
+__all__ = (
+    'AbstractToolset',
+    'ToolsetTool',
+    'CombinedToolset',
+    'DeferredToolset',
+    'FilteredToolset',
+    'FunctionToolset',
+    'PrefixedToolset',
+    'RenamedToolset',
+    'PreparedToolset',
+    'WrapperToolset',
+)