pydantic_ai_slim-0.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/models/function.py
@@ -0,0 +1,268 @@
+ """A model controlled by a local function.
+
+ [FunctionModel][pydantic_ai.models.function.FunctionModel] is similar to [TestModel][pydantic_ai.models.test.TestModel],
+ but allows greater control over the model's behavior.
+
+ Its primary use case is more advanced unit testing than is possible with `TestModel`.
+ """
+
+ from __future__ import annotations as _annotations
+
+ import inspect
+ import re
+ from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from itertools import chain
+ from typing import Callable, Union, cast
+
+ import pydantic_core
+ from typing_extensions import TypeAlias, assert_never, overload
+
+ from .. import _utils, result
+ from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
+ from . import (
+     AbstractToolDefinition,
+     AgentModel,
+     EitherStreamedResponse,
+     Model,
+     StreamStructuredResponse,
+     StreamTextResponse,
+ )
+
+
+ @dataclass(init=False)
+ class FunctionModel(Model):
+     """A model controlled by a local function.
+
+     Apart from `__init__`, all methods are private or match those of the base class.
+     """
+
+     function: FunctionDef | None = None
+     stream_function: StreamFunctionDef | None = None
+
+     @overload
+     def __init__(self, function: FunctionDef) -> None: ...
+
+     @overload
+     def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...
+
+     @overload
+     def __init__(self, function: FunctionDef, *, stream_function: StreamFunctionDef) -> None: ...
+
+     def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None):
+         """Initialize a `FunctionModel`.
+
+         Either `function` or `stream_function` must be provided; providing both is allowed.
+
+         Args:
+             function: The function to call for non-streamed requests.
+             stream_function: The function to call for streamed requests.
+         """
+         if function is None and stream_function is None:
+             raise TypeError('Either `function` or `stream_function` must be provided')
+         self.function = function
+         self.stream_function = stream_function
+
+     async def agent_model(
+         self,
+         retrievers: Mapping[str, AbstractToolDefinition],
+         allow_text_result: bool,
+         result_tools: Sequence[AbstractToolDefinition] | None,
+     ) -> AgentModel:
+         result_tools = list(result_tools) if result_tools is not None else None
+         return FunctionAgentModel(
+             self.function, self.stream_function, AgentInfo(retrievers, allow_text_result, result_tools)
+         )
+
+     def name(self) -> str:
+         labels: list[str] = []
+         if self.function is not None:
+             labels.append(self.function.__name__)
+         if self.stream_function is not None:
+             labels.append(f'stream-{self.stream_function.__name__}')
+         return f'function:{",".join(labels)}'
+
+
+ @dataclass(frozen=True)
+ class AgentInfo:
+     """Information about an agent.
+
+     This is passed as the second argument to functions.
+     """
+
+     retrievers: Mapping[str, AbstractToolDefinition]
+     """The retrievers available on this agent."""
+     allow_text_result: bool
+     """Whether a plain text result is allowed."""
+     result_tools: list[AbstractToolDefinition] | None
+     """The tools that can be called as the final result of the run."""
+
+
+ @dataclass
+ class DeltaToolCall:
+     """Incremental change to a tool call.
+
+     Used to describe a chunk when streaming structured responses.
+     """
+
+     name: str | None = None
+     """Incremental change to the name of the tool."""
+     json_args: str | None = None
+     """Incremental change to the arguments as JSON."""
+
+
+ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
+ """A mapping of tool call IDs to incremental changes."""
+
+ FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]]
+ """A function used to generate a non-streamed response."""
+
+ StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+ """A function used to generate a streamed response.
+
+ While this is defined as having a return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
+ really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]]`.
+
+ That is, you need to yield all text or all `DeltaToolCalls`, not mix them.
+ """
+
+
+ @dataclass
+ class FunctionAgentModel(AgentModel):
+     """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+
+     function: FunctionDef | None
+     stream_function: StreamFunctionDef | None
+     agent_info: AgentInfo
+
+     async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
+         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+         if inspect.iscoroutinefunction(self.function):
+             response = await self.function(messages, self.agent_info)
+         else:
+             response_ = await _utils.run_in_executor(self.function, messages, self.agent_info)
+             response = cast(ModelAnyResponse, response_)
+         # TODO is `messages` right here? Should it just be new messages?
+         return response, _estimate_cost(chain(messages, [response]))
+
+     @asynccontextmanager
+     async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
+         assert (
+             self.stream_function is not None
+         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+         response_stream = self.stream_function(messages, self.agent_info)
+         try:
+             first = await response_stream.__anext__()
+         except StopAsyncIteration as e:
+             raise ValueError('Stream function must return at least one item') from e
+
+         if isinstance(first, str):
+             text_stream = cast(AsyncIterator[str], response_stream)
+             yield FunctionStreamTextResponse(first, text_stream)
+         else:
+             structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
+             yield FunctionStreamStructuredResponse(first, structured_stream)
+
+
+ @dataclass
+ class FunctionStreamTextResponse(StreamTextResponse):
+     """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+
+     _next: str | None
+     _iter: AsyncIterator[str]
+     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
+     _buffer: list[str] = field(default_factory=list, init=False)
+
+     async def __anext__(self) -> None:
+         if self._next is not None:
+             self._buffer.append(self._next)
+             self._next = None
+         else:
+             self._buffer.append(await self._iter.__anext__())
+
+     def get(self, *, final: bool = False) -> Iterable[str]:
+         yield from self._buffer
+         self._buffer.clear()
+
+     def cost(self) -> result.Cost:
+         return result.Cost()
+
+     def timestamp(self) -> datetime:
+         return self._timestamp
+
+
+ @dataclass
+ class FunctionStreamStructuredResponse(StreamStructuredResponse):
+     """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+
+     _next: DeltaToolCalls | None
+     _iter: AsyncIterator[DeltaToolCalls]
+     _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
+     _timestamp: datetime = field(default_factory=_utils.now_utc)
+
+     async def __anext__(self) -> None:
+         if self._next is not None:
+             tool_call = self._next
+             self._next = None
+         else:
+             tool_call = await self._iter.__anext__()
+
+         for key, new in tool_call.items():
+             if current := self._delta_tool_calls.get(key):
+                 current.name = _utils.add_optional(current.name, new.name)
+                 current.json_args = _utils.add_optional(current.json_args, new.json_args)
+             else:
+                 self._delta_tool_calls[key] = new
+
+     def get(self, *, final: bool = False) -> ModelStructuredResponse:
+         calls: list[ToolCall] = []
+         for c in self._delta_tool_calls.values():
+             if c.name is not None and c.json_args is not None:
+                 calls.append(ToolCall.from_json(c.name, c.json_args))
+
+         return ModelStructuredResponse(calls, timestamp=self._timestamp)
+
+     def cost(self) -> result.Cost:
+         return result.Cost()
+
+     def timestamp(self) -> datetime:
+         return self._timestamp
+
+
+ def _estimate_cost(messages: Iterable[Message]) -> result.Cost:
+     """Very rough guesstimate of the number of tokens associated with a series of messages.
+
+     This is designed to be used solely to give plausible numbers for testing!
+     """
+     # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
+
+     request_tokens = 50
+     response_tokens = 0
+     for message in messages:
+         if message.role == 'system' or message.role == 'user':
+             request_tokens += _string_cost(message.content)
+         elif message.role == 'tool-return':
+             request_tokens += _string_cost(message.model_response_str())
+         elif message.role == 'retry-prompt':
+             request_tokens += _string_cost(message.model_response())
+         elif message.role == 'model-text-response':
+             response_tokens += _string_cost(message.content)
+         elif message.role == 'model-structured-response':
+             for call in message.calls:
+                 if isinstance(call.args, ArgsJson):
+                     args_str = call.args.args_json
+                 else:
+                     args_str = pydantic_core.to_json(call.args.args_object).decode()
+
+                 response_tokens += 1 + _string_cost(args_str)
+         else:
+             assert_never(message)
+     return result.Cost(
+         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
+     )
+
+
+ def _string_cost(content: str) -> int:
+     return len(re.split(r'[\s",.:]+', content))
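
For context, here is a minimal sketch of how FunctionModel might drive a non-streamed test run. It assumes the package's public Agent API (Agent(), run_sync(..., model=...), result.data) and a ModelTextResponse message type with a content field; none of those are defined in this file, so treat them as assumptions rather than confirmed 0.0.6 APIs.

from pydantic_ai import Agent
from pydantic_ai.messages import Message, ModelTextResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel

agent = Agent()


def echo_model(messages: list[Message], info: AgentInfo) -> ModelTextResponse:
    # Assumed message shape: the last message is the user prompt with a `content` str.
    return ModelTextResponse(content=f'echo: {messages[-1].content}')


# Hypothetical usage: override the model for a single test run.
result = agent.run_sync('hello', model=FunctionModel(echo_model))
print(result.data)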
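
For streamed requests, a minimal sketch of stream functions, using only names defined in this file (FunctionModel, AgentInfo, DeltaToolCall, DeltaToolCalls) plus the public import path the docstrings reference. It follows the rule from the StreamFunctionDef docstring: a stream yields either text chunks or DeltaToolCalls mappings, never a mix.

from collections.abc import AsyncIterator

from pydantic_ai.messages import Message
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel


async def stream_text(messages: list[Message], info: AgentInfo) -> AsyncIterator[str]:
    # Text-only stream: every yielded item is a str.
    for chunk in ('hello', ' ', 'world'):
        yield chunk


async def stream_tool_call(messages: list[Message], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
    # Structured stream: build up tool call 0 incrementally, name first, then JSON args
    # ('final_result' and the args are illustrative values only).
    yield {0: DeltaToolCall(name='final_result')}
    yield {0: DeltaToolCall(json_args='{"answer": 42}')}


text_model = FunctionModel(stream_function=stream_text)
tool_model = FunctionModel(stream_function=stream_tool_call)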