pydantic-ai-slim 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +8 -0
- pydantic_ai/_griffe.py +128 -0
- pydantic_ai/_pydantic.py +216 -0
- pydantic_ai/_result.py +258 -0
- pydantic_ai/_retriever.py +114 -0
- pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai/_utils.py +247 -0
- pydantic_ai/agent.py +795 -0
- pydantic_ai/dependencies.py +83 -0
- pydantic_ai/exceptions.py +56 -0
- pydantic_ai/messages.py +205 -0
- pydantic_ai/models/__init__.py +300 -0
- pydantic_ai/models/function.py +268 -0
- pydantic_ai/models/gemini.py +720 -0
- pydantic_ai/models/groq.py +400 -0
- pydantic_ai/models/openai.py +379 -0
- pydantic_ai/models/test.py +389 -0
- pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai/py.typed +0 -0
- pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6.dist-info/METADATA +49 -0
- pydantic_ai_slim-0.0.6.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""A model controlled by a local function.
|
|
2
|
+
|
|
3
|
+
[FunctionModel][pydantic_ai.models.function.FunctionModel] is similar to [TestModel][pydantic_ai.models.test.TestModel],
|
|
4
|
+
but allows greater control over the model's behavior.
|
|
5
|
+
|
|
6
|
+
Its primary use case is more advanced unit testing than is possible with `TestModel`.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations as _annotations
|
|
10
|
+
|
|
11
|
+
import inspect
|
|
12
|
+
import re
|
|
13
|
+
from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
|
|
14
|
+
from contextlib import asynccontextmanager
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from itertools import chain
|
|
18
|
+
from typing import Callable, Union, cast
|
|
19
|
+
|
|
20
|
+
import pydantic_core
|
|
21
|
+
from typing_extensions import TypeAlias, assert_never, overload
|
|
22
|
+
|
|
23
|
+
from .. import _utils, result
|
|
24
|
+
from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
|
|
25
|
+
from . import (
|
|
26
|
+
AbstractToolDefinition,
|
|
27
|
+
AgentModel,
|
|
28
|
+
EitherStreamedResponse,
|
|
29
|
+
Model,
|
|
30
|
+
StreamStructuredResponse,
|
|
31
|
+
StreamTextResponse,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(init=False)
class FunctionModel(Model):
    """A model whose responses are produced by local Python functions.

    Other than `__init__`, every method is either private or an implementation of the base-class interface.
    """

    function: FunctionDef | None = None
    stream_function: StreamFunctionDef | None = None

    @overload
    def __init__(self, function: FunctionDef) -> None: ...

    @overload
    def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...

    @overload
    def __init__(self, function: FunctionDef, *, stream_function: StreamFunctionDef) -> None: ...

    def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None):
        """Create a `FunctionModel`.

        At least one of `function` and `stream_function` is required; passing both is fine.

        Args:
            function: Called to produce the response for non-streamed requests.
            stream_function: Called to produce the response for streamed requests.

        Raises:
            TypeError: If neither `function` nor `stream_function` is given.
        """
        neither_given = function is None and stream_function is None
        if neither_given:
            raise TypeError('Either `function` or `stream_function` must be provided')
        self.function, self.stream_function = function, stream_function

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> AgentModel:
        """Wrap this model's functions in a `FunctionAgentModel` together with an `AgentInfo` for the agent.

        `result_tools`, when given, is copied into a list before being stored on the `AgentInfo`.
        """
        tools_list = None if result_tools is None else list(result_tools)
        info = AgentInfo(retrievers=retrievers, allow_text_result=allow_text_result, result_tools=tools_list)
        return FunctionAgentModel(function=self.function, stream_function=self.stream_function, agent_info=info)

    def name(self) -> str:
        """Return `function:` followed by the comma-separated names of the configured functions.

        The non-streamed function comes first; the streamed function's name is prefixed with `stream-`.
        """
        plain = [] if self.function is None else [self.function.__name__]
        streamed = [] if self.stream_function is None else [f'stream-{self.stream_function.__name__}']
        return f'function:{",".join([*plain, *streamed])}'
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(frozen=True)
class AgentInfo:
    """Information about an agent.

    An instance is passed as the second argument to the functions used by
    [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
    """

    retrievers: Mapping[str, AbstractToolDefinition]
    """The retrievers available on this agent."""
    allow_text_result: bool
    """Whether a plain text result is allowed."""
    result_tools: list[AbstractToolDefinition] | None
    """The tools that can be called as the final result of the run."""
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass
class DeltaToolCall:
    """A partial update to a single tool call.

    Chunks of this type are produced while a structured response is being streamed.
    Both fields default to `None`.
    """

    name: str | None = field(default=None)
    """Incremental change to the name of the tool."""
    json_args: str | None = field(default=None)
    """Incremental change to the arguments as JSON."""
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
"""A mapping of integer tool call IDs to incremental changes to those tool calls."""

FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]]
"""A function used to generate a non-streamed response.

It may return the response directly, or return an awaitable that resolves to the response.
"""

StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
"""A function used to generate a streamed response.

While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]]`.

I.e. you need to yield all text or all `DeltaToolCalls`, not mix them.
"""
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass
class FunctionAgentModel(AgentModel):
    """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""

    function: FunctionDef | None
    stream_function: StreamFunctionDef | None
    agent_info: AgentInfo

    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
        """Produce a non-streamed response by calling `function`.

        Coroutine functions are awaited directly. Any other callable is run via `_utils.run_in_executor`.
        If that call hands back an awaitable (which `FunctionDef` permits), it is awaited as well.

        Returns:
            The response and an estimated cost computed from `messages` plus the response.

        Raises:
            AssertionError: If this model was created without a `function`.
        """
        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
        if inspect.iscoroutinefunction(self.function):
            response = await self.function(messages, self.agent_info)
        else:
            response_ = await _utils.run_in_executor(self.function, messages, self.agent_info)
            if inspect.isawaitable(response_):
                # A non-coroutine callable may still return an awaitable, e.g. a lambda wrapping an async
                # function or an object with an async `__call__`. Without awaiting it here, the coroutine
                # object itself would be returned as the "response".
                response_ = await response_
            response = cast(ModelAnyResponse, response_)
        # TODO is `messages` right here? Should it just be new messages?
        return response, _estimate_cost(chain(messages, [response]))

    @asynccontextmanager
    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
        """Produce a streamed response by calling `stream_function`.

        The first item is pulled eagerly to decide the response kind. A `str` yields a
        `FunctionStreamTextResponse`; anything else is treated as `DeltaToolCalls` and yields a
        `FunctionStreamStructuredResponse`.

        Raises:
            AssertionError: If this model was created without a `stream_function`.
            ValueError: If the stream produces no items at all.
        """
        assert (
            self.stream_function is not None
        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
        response_stream = self.stream_function(messages, self.agent_info)
        try:
            first = await response_stream.__anext__()
        except StopAsyncIteration as e:
            raise ValueError('Stream function must return at least one item') from e

        if isinstance(first, str):
            text_stream = cast(AsyncIterator[str], response_stream)
            yield FunctionStreamTextResponse(first, text_stream)
        else:
            structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
            yield FunctionStreamStructuredResponse(first, structured_stream)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@dataclass
class FunctionStreamTextResponse(StreamTextResponse):
    """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""

    # First chunk, already pulled from the stream by `request_stream`; consumed by the first `__anext__`.
    _next: str | None
    _iter: AsyncIterator[str]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
    # Chunks received but not yet handed out by `get`.
    _buffer: list[str] = field(default_factory=list, init=False)

    async def __anext__(self) -> None:
        """Move one more text chunk into the buffer.

        The pre-fetched first chunk is used first. After that, chunks come from the underlying iterator,
        whose `StopAsyncIteration` propagates when the stream ends.
        """
        if self._next is not None:
            self._buffer.append(self._next)
            self._next = None
        else:
            self._buffer.append(await self._iter.__anext__())

    def get(self, *, final: bool = False) -> Iterable[str]:
        """Return the text chunks buffered since the previous call, and start a fresh buffer.

        The buffer is swapped out eagerly, so each chunk is returned exactly once. This holds even if the
        caller never iterates the result, or only iterates part of it. `final` is accepted for interface
        compatibility but is not used here.
        """
        chunks, self._buffer = self._buffer, []
        return chunks

    def cost(self) -> result.Cost:
        """Return an empty `Cost`; no usage is tracked for function-driven streams."""
        return result.Cost()

    def timestamp(self) -> datetime:
        """Return the time recorded when this response object was created."""
        return self._timestamp
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@dataclass
class FunctionStreamStructuredResponse(StreamStructuredResponse):
    """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""

    # First chunk, already pulled from the stream by `request_stream`; consumed by the first `__anext__`.
    _next: DeltaToolCalls | None
    _iter: AsyncIterator[DeltaToolCalls]
    # Accumulated state of each tool call, keyed by the ID used in the yielded `DeltaToolCalls`.
    _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
    _timestamp: datetime = field(default_factory=_utils.now_utc)

    async def __anext__(self) -> None:
        """Merge one more chunk of tool-call deltas into the accumulated state.

        The pre-fetched first chunk is used first. After that, chunks come from the underlying iterator,
        whose `StopAsyncIteration` propagates when the stream ends.
        """
        if self._next is not None:
            tool_call = self._next
            self._next = None
        else:
            tool_call = await self._iter.__anext__()

        for key, new in tool_call.items():
            current = self._delta_tool_calls.get(key)
            if current is None:
                # Store a copy. The merge branch below mutates the stored object in place, and mutating the
                # stream function's own `DeltaToolCall` would corrupt it (e.g. double its content if the same
                # object is yielded again).
                self._delta_tool_calls[key] = DeltaToolCall(name=new.name, json_args=new.json_args)
            else:
                current.name = _utils.add_optional(current.name, new.name)
                current.json_args = _utils.add_optional(current.json_args, new.json_args)

    def get(self, *, final: bool = False) -> ModelStructuredResponse:
        """Build a `ModelStructuredResponse` from the tool calls accumulated so far.

        Tool calls whose name or JSON arguments have not been received yet are left out.
        `final` is accepted for interface compatibility but is not used here.
        """
        calls: list[ToolCall] = []
        for c in self._delta_tool_calls.values():
            if c.name is not None and c.json_args is not None:
                calls.append(ToolCall.from_json(c.name, c.json_args))

        return ModelStructuredResponse(calls, timestamp=self._timestamp)

    def cost(self) -> result.Cost:
        """Return an empty `Cost`; no usage is tracked for function-driven streams."""
        return result.Cost()

    def timestamp(self) -> datetime:
        """Return the time recorded when this response object was created."""
        return self._timestamp
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _estimate_cost(messages: Iterable[Message]) -> result.Cost:
    """Roughly guess how many tokens are associated with a series of messages.

    This is only meant to produce plausible-looking numbers for testing; it is not a real tokenizer.
    System, user, tool-return and retry-prompt messages count toward `request_tokens`. Model text and
    structured responses count toward `response_tokens`.
    """
    # Gemini and OpenAI calls both seem to carry about 50 tokens of overhead, so start from there ¯\_(ツ)_/¯
    req_tokens = 50
    resp_tokens = 0
    for msg in messages:
        if msg.role == 'system' or msg.role == 'user':
            req_tokens += _string_cost(msg.content)
        elif msg.role == 'tool-return':
            req_tokens += _string_cost(msg.model_response_str())
        elif msg.role == 'retry-prompt':
            req_tokens += _string_cost(msg.model_response())
        elif msg.role == 'model-text-response':
            resp_tokens += _string_cost(msg.content)
        elif msg.role == 'model-structured-response':
            for tool_call in msg.calls:
                call_args = tool_call.args
                json_text = (
                    call_args.args_json
                    if isinstance(call_args, ArgsJson)
                    else pydantic_core.to_json(call_args.args_object).decode()
                )
                # One extra token per call on top of the arguments (presumably standing in for the tool name).
                resp_tokens += 1 + _string_cost(json_text)
        else:
            assert_never(msg)
    return result.Cost(
        request_tokens=req_tokens,
        response_tokens=resp_tokens,
        total_tokens=req_tokens + resp_tokens,
    )
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _string_cost(content: str) -> int:
|
|
268
|
+
return len(re.split(r'[\s",.:]+', content))
|