pydantic-ai-slim 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +8 -0
- pydantic_ai/_griffe.py +128 -0
- pydantic_ai/_pydantic.py +216 -0
- pydantic_ai/_result.py +258 -0
- pydantic_ai/_retriever.py +114 -0
- pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai/_utils.py +247 -0
- pydantic_ai/agent.py +795 -0
- pydantic_ai/dependencies.py +83 -0
- pydantic_ai/exceptions.py +56 -0
- pydantic_ai/messages.py +205 -0
- pydantic_ai/models/__init__.py +300 -0
- pydantic_ai/models/function.py +268 -0
- pydantic_ai/models/gemini.py +720 -0
- pydantic_ai/models/groq.py +400 -0
- pydantic_ai/models/openai.py +379 -0
- pydantic_ai/models/test.py +389 -0
- pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai/py.typed +0 -0
- pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6.dist-info/METADATA +49 -0
- pydantic_ai_slim-0.0.6.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Awaitable, Mapping, Sequence
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
|
|
6
|
+
|
|
7
|
+
from typing_extensions import Concatenate, ParamSpec, TypeAlias
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .result import ResultData
|
|
11
|
+
else:
|
|
12
|
+
ResultData = Any
|
|
13
|
+
|
|
14
|
+
__all__ = (
|
|
15
|
+
'AgentDeps',
|
|
16
|
+
'CallContext',
|
|
17
|
+
'ResultValidatorFunc',
|
|
18
|
+
'SystemPromptFunc',
|
|
19
|
+
'RetrieverReturnValue',
|
|
20
|
+
'RetrieverContextFunc',
|
|
21
|
+
'RetrieverPlainFunc',
|
|
22
|
+
'RetrieverParams',
|
|
23
|
+
'JsonData',
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
AgentDeps = TypeVar('AgentDeps')
|
|
27
|
+
"""Type variable for agent dependencies."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class CallContext(Generic[AgentDeps]):
    """Details of the call in progress, generic over the agent's dependency type."""

    deps: AgentDeps
    """The agent's dependencies."""
    retry: int
    """How many retries have happened so far."""
    tool_name: str | None
    """The name of the tool being called, if any."""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
RetrieverParams = ParamSpec('RetrieverParams')
|
|
43
|
+
"""Retrieval function param spec."""
|
|
44
|
+
|
|
45
|
+
SystemPromptFunc = Union[
|
|
46
|
+
Callable[[CallContext[AgentDeps]], str],
|
|
47
|
+
Callable[[CallContext[AgentDeps]], Awaitable[str]],
|
|
48
|
+
Callable[[], str],
|
|
49
|
+
Callable[[], Awaitable[str]],
|
|
50
|
+
]
|
|
51
|
+
"""A function that may or may not take `CallContext` as an argument, and may or may not be async.
|
|
52
|
+
|
|
53
|
+
Usage `SystemPromptFunc[AgentDeps]`.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
ResultValidatorFunc = Union[
|
|
57
|
+
Callable[[CallContext[AgentDeps], ResultData], ResultData],
|
|
58
|
+
Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
59
|
+
Callable[[ResultData], ResultData],
|
|
60
|
+
Callable[[ResultData], Awaitable[ResultData]],
|
|
61
|
+
]
|
|
62
|
+
"""
|
|
63
|
+
A function that always takes `ResultData` and returns `ResultData`,
|
|
64
|
+
but may or may not take `CallContext` as a first argument, and may or may not be async.
|
|
65
|
+
|
|
66
|
+
Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
|
|
70
|
+
"""Type representing any JSON data."""
|
|
71
|
+
|
|
72
|
+
RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]]
|
|
73
|
+
"""Return value of a retriever function."""
|
|
74
|
+
RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue]
|
|
75
|
+
"""A retriever function that takes `CallContext` as the first argument.
|
|
76
|
+
|
|
77
|
+
Usage `RetrieverContextFunc[AgentDeps, RetrieverParams]`.
|
|
78
|
+
"""
|
|
79
|
+
RetrieverPlainFunc = Callable[RetrieverParams, RetrieverReturnValue]
|
|
80
|
+
"""A retriever function that does not take `CallContext` as the first argument.
|
|
81
|
+
|
|
82
|
+
Usage `RetrieverPlainFunc[RetrieverParams]`.
|
|
83
|
+
"""
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ModelRetry(Exception):
    """Raised to signal that a retriever function should be retried.

    The agent hands the message back to the model and asks it to call the function/tool again.
    """

    message: str
    """The message sent back to the model."""

    def __init__(self, message: str):
        # Initialise the base exception first so `args` and `str()` carry the message.
        super().__init__(message)
        self.message = message
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class UserError(RuntimeError):
    """Raised when the application developer has used the library incorrectly."""

    message: str
    """What the mistake was."""

    def __init__(self, message: str):
        # Initialise the base exception first so `args` and `str()` carry the message.
        super().__init__(message)
        self.message = message
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UnexpectedModelBehavior(RuntimeError):
    """Raised when a Model behaves unexpectedly, e.g. returns an unexpected response code."""

    message: str
    """What the unexpected behavior was."""
    body: str | None
    """The response body, when one is available."""

    def __init__(self, message: str, body: str | None = None):
        super().__init__(message)
        self.message = message
        # Start from the raw body; it is replaced below only when it parses as JSON.
        self.body = body
        if body is not None:
            try:
                # Pretty-print JSON bodies so they are readable in error output.
                self.body = json.dumps(json.loads(body), indent=2)
            except ValueError:
                # Not valid JSON: keep the body exactly as received.
                pass

    def __str__(self) -> str:
        # An absent or empty body contributes nothing beyond the message.
        if not self.body:
            return self.message
        return f'{self.message}, body:\n{self.body}'
|
pydantic_ai/messages.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Mapping, Sequence
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
|
|
8
|
+
|
|
9
|
+
import pydantic
|
|
10
|
+
import pydantic_core
|
|
11
|
+
from pydantic import TypeAdapter
|
|
12
|
+
from typing_extensions import TypeAlias, TypeAliasType
|
|
13
|
+
|
|
14
|
+
from . import _pydantic
|
|
15
|
+
from ._utils import now_utc as _now_utc
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class SystemPrompt:
    """A system prompt, usually authored by the application developer.

    It provides the model with context and guidance on how to respond.
    """

    content: str
    """The prompt text."""
    role: Literal['system'] = 'system'
    """Message type identifier; present on all messages as a discriminator."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class UserPrompt:
    """A user prompt, usually authored by the end user.

    The content is taken from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
    [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
    """

    content: str
    """The prompt text."""
    timestamp: datetime = field(default_factory=_now_utc)
    """When the prompt was created."""
    role: Literal['user'] = 'user'
    """Message type identifier; present on all messages as a discriminator."""
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]'
|
|
48
|
+
if not TYPE_CHECKING:
|
|
49
|
+
# work around for https://github.com/pydantic/pydantic/issues/10873
|
|
50
|
+
# this is needed for pydantic to work with both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
|
|
51
|
+
JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')
|
|
52
|
+
|
|
53
|
+
json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class ToolReturn:
    """A tool return message, carrying the result of running a retriever."""

    tool_name: str
    """The name of the "tool" that was called."""
    content: JsonData
    """The value that was returned."""
    tool_id: str | None = None
    """Optional tool identifier, used by some models including OpenAI."""
    timestamp: datetime = field(default_factory=_now_utc)
    """When the tool returned."""
    role: Literal['tool-return'] = 'tool-return'
    """Message type identifier; present on all messages as a discriminator."""

    def model_response_str(self) -> str:
        """Return the content as a string: unchanged if already a `str`, otherwise encoded as JSON."""
        if isinstance(self.content, str):
            return self.content
        validated = json_ta.validate_python(self.content)
        return json_ta.dump_json(validated).decode()

    def model_response_object(self) -> dict[str, JsonData]:
        """Return the content as a dict, wrapping any non-dict value under the `'return_value'` key."""
        validated = json_ta.validate_python(self.content)
        # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
        if isinstance(self.content, dict):
            return validated  # pyright: ignore[reportReturnType]
        return {'return_value': validated}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class RetryPrompt:
    """A message back to a model asking it to try again.

    This can be sent for a number of reasons:

    * Pydantic validation of retriever arguments failed, here content is derived from a Pydantic
      [`ValidationError`][pydantic_core.ValidationError]
    * a retriever raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
    * no retriever was found for the tool name
    * the model returned plain text when a structured response was expected
    * Pydantic validation of a structured response failed, here content is derived from a Pydantic
      [`ValidationError`][pydantic_core.ValidationError]
    * a result validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception
    """

    content: list[pydantic_core.ErrorDetails] | str
    """Details of why and how the model should retry.

    If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of
    error details.
    """
    tool_name: str | None = None
    """The name of the tool that was called, if any."""
    tool_id: str | None = None
    """The tool identifier, if any."""
    timestamp: datetime = field(default_factory=_now_utc)
    """The timestamp, when the retry was triggered."""
    role: Literal['retry-prompt'] = 'retry-prompt'
    """Message type identifier, this type is available on all messages as a discriminator."""

    def model_response(self) -> str:
        """Build the text sent back to the model.

        Returns:
            The string content as-is, or a count plus JSON dump of the validation errors, followed in
            both cases by an instruction to fix the errors and try again.
        """
        if isinstance(self.content, str):
            description = self.content
        else:
            # Error details can hold values the `json` module cannot encode (for example an exception
            # instance under `ctx`, or an arbitrary Python object under `input`). Fall back to `str()`
            # for those so that building the retry message never raises `TypeError`; output for
            # already-serializable details is unchanged.
            errors_json = json.dumps(self.content, indent=2, default=str)
            description = f'{len(self.content)} validation errors: {errors_json}'
        return f'{description}\n\nFix the errors and try again.'
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass
class ModelTextResponse:
    """A model response consisting of plain text."""

    content: str
    """The text of the response."""
    timestamp: datetime = field(default_factory=_now_utc)
    """When the response was produced.

    If the model supplies a timestamp in its response (as OpenAI does), that value is used.
    """
    role: Literal['model-text-response'] = 'model-text-response'
    """Message type identifier; present on all messages as a discriminator."""
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass
class ArgsJson:
    """Tool call arguments held as a JSON string."""

    args_json: str
    """The arguments, as a JSON string."""
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass
class ArgsObject:
    """Tool call arguments held as a Python dictionary."""

    args_object: dict[str, Any]
    """The arguments, as a Python dictionary."""
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclass
class ToolCall:
    """A single tool call: the tool's name plus the arguments to pass to it."""

    tool_name: str
    """The name of the tool to call."""
    args: ArgsJson | ArgsObject
    """The arguments to pass to the tool.

    Held either as JSON or as a Python dictionary, depending on how the data was returned.
    """
    tool_id: str | None = None
    """Optional tool identifier, used by some models including OpenAI."""

    @classmethod
    def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -> ToolCall:
        """Build a tool call whose arguments are a JSON string."""
        json_args = ArgsJson(args_json)
        return cls(tool_name, json_args, tool_id)

    @classmethod
    def from_object(cls, tool_name: str, args_object: dict[str, Any]) -> ToolCall:
        """Build a tool call whose arguments are a Python dictionary (no tool identifier is set)."""
        object_args = ArgsObject(args_object)
        return cls(tool_name, object_args)

    def has_content(self) -> bool:
        """Report whether the call carries any arguments.

        For dictionary arguments this is true only if at least one value is truthy; for JSON
        arguments it is true if the JSON string is non-empty.
        """
        call_args = self.args
        if isinstance(call_args, ArgsObject):
            # NOTE(review): falsy values such as `0`, `False` or `''` count as "no content" here.
            return any(call_args.args_object.values())
        return bool(call_args.args_json)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@dataclass
class ModelStructuredResponse:
    """A model response made up of tool calls.

    Used either to call a retriever or to return a structured response from an agent run.
    """

    calls: list[ToolCall]
    """The tool calls the response contains."""
    timestamp: datetime = field(default_factory=_now_utc)
    """When the response was produced.

    If the model supplies a timestamp in its response (as OpenAI does), that value is used.
    """
    role: Literal['model-structured-response'] = 'model-structured-response'
    """Message type identifier; present on all messages as a discriminator."""
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
ModelAnyResponse = Union[ModelTextResponse, ModelStructuredResponse]
|
|
200
|
+
"""Any response from a model."""
|
|
201
|
+
Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelTextResponse, ModelStructuredResponse]
|
|
202
|
+
"""Any message sent to or returned by a model."""
|
|
203
|
+
|
|
204
|
+
MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
|
|
205
|
+
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Logic related to making requests to an LLM.
|
|
2
|
+
|
|
3
|
+
The aim here is to make a common interface for different LLMs, so that the rest of the code can be agnostic to the
|
|
4
|
+
specific LLM being used.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations as _annotations
|
|
8
|
+
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
|
|
11
|
+
from contextlib import asynccontextmanager, contextmanager
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from functools import cache
|
|
14
|
+
from typing import TYPE_CHECKING, Literal, Protocol, Union
|
|
15
|
+
|
|
16
|
+
import httpx
|
|
17
|
+
|
|
18
|
+
from ..exceptions import UserError
|
|
19
|
+
from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from .._utils import ObjectJsonSchema
|
|
23
|
+
from ..result import Cost
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
KnownModelName = Literal[
|
|
27
|
+
'openai:gpt-4o',
|
|
28
|
+
'openai:gpt-4o-mini',
|
|
29
|
+
'openai:gpt-4-turbo',
|
|
30
|
+
'openai:gpt-4',
|
|
31
|
+
'openai:o1-preview',
|
|
32
|
+
'openai:o1-mini',
|
|
33
|
+
'openai:gpt-3.5-turbo',
|
|
34
|
+
'groq:llama-3.1-70b-versatile',
|
|
35
|
+
'groq:llama3-groq-70b-8192-tool-use-preview',
|
|
36
|
+
'groq:llama3-groq-8b-8192-tool-use-preview',
|
|
37
|
+
'groq:llama-3.1-70b-specdec',
|
|
38
|
+
'groq:llama-3.1-8b-instant',
|
|
39
|
+
'groq:llama-3.2-1b-preview',
|
|
40
|
+
'groq:llama-3.2-3b-preview',
|
|
41
|
+
'groq:llama-3.2-11b-vision-preview',
|
|
42
|
+
'groq:llama-3.2-90b-vision-preview',
|
|
43
|
+
'groq:llama3-70b-8192',
|
|
44
|
+
'groq:llama3-8b-8192',
|
|
45
|
+
'groq:mixtral-8x7b-32768',
|
|
46
|
+
'groq:gemma2-9b-it',
|
|
47
|
+
'groq:gemma-7b-it',
|
|
48
|
+
'gemini-1.5-flash',
|
|
49
|
+
'gemini-1.5-pro',
|
|
50
|
+
'vertexai:gemini-1.5-flash',
|
|
51
|
+
'vertexai:gemini-1.5-pro',
|
|
52
|
+
'test',
|
|
53
|
+
]
|
|
54
|
+
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
|
|
55
|
+
|
|
56
|
+
`KnownModelName` is provided as a concise way to specify a model.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Model(ABC):
    """Abstract base class for models."""

    @abstractmethod
    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> AgentModel:
        """Build the agent-specific model.

        Declared async so that slow or async configuration checks, which cannot run in `__init__`,
        can be performed here.

        Args:
            retrievers: The retrievers the agent can use.
            allow_text_result: Whether a plain text final response/result is allowed.
            result_tools: Definitions of the final result tool(s), if there are any.

        Returns:
            An agent model.
        """
        raise NotImplementedError

    @abstractmethod
    def name(self) -> str:
        """Return the model's name as a string; subclasses must implement this."""
        raise NotImplementedError
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class AgentModel(ABC):
    """A model that has been configured for one particular agent."""

    @abstractmethod
    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
        """Send the messages to the model and return its response together with the cost."""
        raise NotImplementedError

    @asynccontextmanager
    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
        """Send the messages to the model and yield a streaming response.

        This default implementation always raises `NotImplementedError`; subclasses that support
        streaming override it.
        """
        raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
        # The unreachable `yield` below makes this function a generator, which type checking requires.
        # noinspection PyUnreachableCode
        yield  # pragma: no cover
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class StreamTextResponse(ABC):
    """A streamed LLM response that returns text."""

    def __aiter__(self) -> AsyncIterator[None]:
        """Return this object as the async iterator over the response; text accumulates as it is consumed.

        The iterator yields `None` so that the work of validating the input and extracting the text
        field is not done for chunks whose value would often be thrown away.
        """
        return self

    @abstractmethod
    async def __anext__(self) -> None:
        """Process the next chunk of the response; `__aiter__` explains why this returns `None`."""
        raise NotImplementedError

    @abstractmethod
    def get(self, *, final: bool = False) -> Iterable[str]:
        """Return an iterable of the text received since the previous `get()` call, i.e. the text delta.

        Args:
            final: If True, this is the final call, made after iteration is complete; the response should be
                fully validated and all text extracted.
        """
        raise NotImplementedError

    @abstractmethod
    def cost(self) -> Cost:
        """Return the cost of the request.

        NOTE: this won't return the full cost until the stream is finished.
        """
        raise NotImplementedError

    @abstractmethod
    def timestamp(self) -> datetime:
        """Return the timestamp of the response."""
        raise NotImplementedError
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class StreamStructuredResponse(ABC):
    """A streamed LLM response that calls a tool."""

    def __aiter__(self) -> AsyncIterator[None]:
        """Return this object as the async iterator over the response; the tool call accumulates as it is consumed.

        The iterator yields `None` so that the final tool call is not built for chunks whose value
        would often be thrown away.
        """
        return self

    @abstractmethod
    async def __anext__(self) -> None:
        """Process the next chunk of the response; `__aiter__` explains why this returns `None`."""
        raise NotImplementedError

    @abstractmethod
    def get(self, *, final: bool = False) -> ModelStructuredResponse:
        """Return the `ModelStructuredResponse` as it stands now.

        Whether the `ModelStructuredResponse` is complete depends on whether the stream has finished.

        Args:
            final: If True, this is the final call, made after iteration is complete; the response should be
                fully validated.
        """
        raise NotImplementedError

    @abstractmethod
    def cost(self) -> Cost:
        """Return the cost of the request.

        NOTE: this won't return the full cost until the stream is finished.
        """
        raise NotImplementedError

    @abstractmethod
    def timestamp(self) -> datetime:
        """Return the timestamp of the response."""
        raise NotImplementedError
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
ALLOW_MODEL_REQUESTS = True
|
|
191
|
+
"""Whether to allow requests to models.
|
|
192
|
+
|
|
193
|
+
This global setting allows you to disable requests to most models, e.g. to make sure you don't accidentally
|
|
194
|
+
make costly requests to a model during tests.
|
|
195
|
+
|
|
196
|
+
The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
|
|
197
|
+
[`FunctionModel`][pydantic_ai.models.function.FunctionModel] are not affected by this setting.
|
|
198
|
+
"""
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def check_allow_model_requests() -> None:
    """Verify that requests to models are currently permitted.

    If you define your own models that have cost or latency associated with their use, call this from
    [`Model.agent_model`][pydantic_ai.models.Model.agent_model].

    Raises:
        RuntimeError: If model requests are not allowed.
    """
    if ALLOW_MODEL_REQUESTS:
        return
    raise RuntimeError('Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False')
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@contextmanager
def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
    """Temporarily override [`ALLOW_MODEL_REQUESTS`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] as a context manager.

    The previous value is restored when the context exits, including when the body raises.

    Args:
        allow_model_requests: Whether model requests are allowed inside the context.
    """
    global ALLOW_MODEL_REQUESTS
    previous = ALLOW_MODEL_REQUESTS
    ALLOW_MODEL_REQUESTS = allow_model_requests  # pyright: ignore[reportConstantRedefinition]
    try:
        yield
    finally:
        ALLOW_MODEL_REQUESTS = previous  # pyright: ignore[reportConstantRedefinition]
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def infer_model(model: Model | KnownModelName) -> Model:
    """Resolve `model` to a `Model` instance.

    A `Model` instance is returned unchanged. A name is dispatched on its prefix to the matching model
    class, which is imported only when needed.

    Raises:
        UserError: If the name does not match any known model.
    """
    if isinstance(model, Model):
        return model

    if model == 'test':
        from .test import TestModel

        return TestModel()

    if model.startswith('openai:'):
        from .openai import OpenAIModel

        return OpenAIModel(model.removeprefix('openai:'))  # pyright: ignore[reportArgumentType]

    if model.startswith('gemini'):
        from .gemini import GeminiModel

        # Gemini names have no provider prefix, so the name is passed through whole.
        # noinspection PyTypeChecker
        return GeminiModel(model)  # pyright: ignore[reportArgumentType]

    if model.startswith('groq:'):
        from .groq import GroqModel

        return GroqModel(model.removeprefix('groq:'))  # pyright: ignore[reportArgumentType]

    if model.startswith('vertexai:'):
        from .vertexai import VertexAIModel

        return VertexAIModel(model.removeprefix('vertexai:'))  # pyright: ignore[reportArgumentType]

    raise UserError(f'Unknown model: {model}')
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class AbstractToolDefinition(Protocol):
    """Structural (protocol) definition of a function/tool.

    Both retrievers and result tools are described by this.
    """

    name: str
    """The tool's name."""
    description: str
    """The tool's description."""
    json_schema: ObjectJsonSchema
    """The JSON schema of the tool's arguments."""
    outer_typed_dict_key: str | None
    """The key in the outer [TypedDict] that wraps a result tool.

    Only set for result tools that don't have an `object` JSON schema.
    """
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@cache
def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
    """Return a cached HTTPX async client that multiple agents and calls can share.

    There are good reasons why in production you should use a `httpx.AsyncClient` as an async context manager as
    described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
    examples it is very useful not to; this lets multiple Agents use a single client.

    The default timeouts match those of OpenAI,
    see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.

    Args:
        timeout: Passed to `httpx.Timeout` as its `timeout` value.
        connect: Passed to `httpx.Timeout` as its `connect` value.
    """
    timeout_config = httpx.Timeout(timeout=timeout, connect=connect)
    default_headers = {'User-Agent': get_user_agent()}
    return httpx.AsyncClient(timeout=timeout_config, headers=default_headers)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@cache
def get_user_agent() -> str:
    """Return the user agent string used by the HTTP client: `pydantic-ai/` followed by the package version."""
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably to avoid a circular import with the package root; confirm.
    from .. import __version__

    user_agent = f'pydantic-ai/{__version__}'
    return user_agent
|