pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
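A recurring theme in the diffs below is the new `pydantic_ai/settings.py` module: it introduces `ModelSettings`, which every model's `request` and `request_stream` now accept as a `model_settings` argument, and which `AgentInfo` exposes to function models. A minimal sketch of how these settings might be supplied at the call site; it assumes `Agent.run_sync` accepts a `model_settings` mapping in this release, the model name is purely illustrative, and only the keys actually read in the Anthropic diff (`max_tokens`, `temperature`, `top_p`, `timeout`) are used:

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # illustrative model name, not taken from this diff

# Unset keys fall back to provider defaults, e.g. the Anthropic model
# defaults max_tokens to 1024 when the setting is absent.
result = agent.run_sync(
    'What is the capital of France?',
    model_settings={'temperature': 0.0, 'max_tokens': 256},
)
print(result.data)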
pydantic_ai/models/anthropic.py ADDED
@@ -0,0 +1,344 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from typing import Any, Literal, Union, cast, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ArgsDict,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    EitherStreamedResponse,
+    Model,
+    cached_async_http_client,
+    check_allow_model_requests,
+)
+
+try:
+    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic.types import (
+        Message as AnthropicMessage,
+        MessageParam,
+        RawMessageDeltaEvent,
+        RawMessageStartEvent,
+        RawMessageStreamEvent,
+        TextBlock,
+        TextBlockParam,
+        ToolChoiceParam,
+        ToolParam,
+        ToolResultBlockParam,
+        ToolUseBlock,
+        ToolUseBlockParam,
+    )
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `anthropic` to use the Anthropic model, '
+        "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
+    ) from _import_error
+
+LatestAnthropicModelNames = Literal[
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+]
+"""Latest named Anthropic models."""
+
+AnthropicModelName = Union[str, LatestAnthropicModelNames]
+"""Possible Anthropic model names.
+
+Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+"""
+
+
+@dataclass(init=False)
+class AnthropicModel(Model):
+    """A model that uses the Anthropic API.
+
+    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+
+    !!! note
+        The `AnthropicModel` class does not yet support streaming responses.
+        We anticipate adding support for streaming responses in a near-term future release.
+    """
+
+    model_name: AnthropicModelName
+    client: AsyncAnthropic = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: AnthropicModelName,
+        *,
+        api_key: str | None = None,
+        anthropic_client: AsyncAnthropic | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Anthropic model.
+
+        Args:
+            model_name: The name of the Anthropic model to use. List of model names available
+                [here](https://docs.anthropic.com/en/docs/about-claude/models).
+            api_key: The API key to use for authentication; if not provided, the `ANTHROPIC_API_KEY` environment variable
+                will be used if available.
+            anthropic_client: An existing
+                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
+                client to use; if provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if anthropic_client is not None:
+            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
+            self.client = anthropic_client
+        elif http_client is not None:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
+        else:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return AnthropicAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return self.model_name
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
+
+
+@dataclass
+class AnthropicAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Anthropic models."""
+
+    client: AsyncAnthropic
+    model_name: str
+    allow_text_result: bool
+    tools: list[ToolParam]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._messages_create(messages, False, model_settings)
+        return self._process_response(response), _map_cost(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._messages_create(messages, True, model_settings)
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+    ) -> AsyncStream[RawMessageStreamEvent]:
+        pass
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> AnthropicMessage:
+        pass
+
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
+        # standalone function to make it easier to override
+        if not self.tools:
+            tool_choice: ToolChoiceParam | None = None
+        elif not self.allow_text_result:
+            tool_choice = {'type': 'any'}
+        else:
+            tool_choice = {'type': 'auto'}
+
+        system_prompt, anthropic_messages = self._map_message(messages)
+
+        model_settings = model_settings or {}
+
+        return await self.client.messages.create(
+            max_tokens=model_settings.get('max_tokens', 1024),
+            system=system_prompt or NOT_GIVEN,
+            messages=anthropic_messages,
+            model=self.model_name,
+            tools=self.tools or NOT_GIVEN,
+            tool_choice=tool_choice or NOT_GIVEN,
+            stream=stream,
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
+        )
+
+    @staticmethod
+    def _process_response(response: AnthropicMessage) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        items: list[ModelResponsePart] = []
+        for item in response.content:
+            if isinstance(item, TextBlock):
+                items.append(TextPart(item.text))
+            else:
+                assert isinstance(item, ToolUseBlock), 'unexpected item type'
+                items.append(
+                    ToolCallPart.from_dict(
+                        item.name,
+                        cast(dict[str, Any], item.input),
+                        item.id,
+                    )
+                )
+
+        return ModelResponse(items)
+
+    @staticmethod
+    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+        """TODO: Process a streamed response, and prepare a streaming response to return."""
+        # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
+        # Streamed responses will be supported in a future release.
+
+        raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
+
+        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
+        # RawMessageStartEvent
+        # RawMessageDeltaEvent
+        # RawMessageStopEvent
+        # RawContentBlockStartEvent
+        # RawContentBlockDeltaEvent
+        # RawContentBlockStopEvent
+        #
+        # We might refactor streaming internally before we implement this...
+
+    @staticmethod
+    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+        """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
+        system_prompt: str = ''
+        anthropic_messages: list[MessageParam] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        anthropic_messages.append(MessageParam(role='user', content=part.content))
+                    elif isinstance(part, ToolReturnPart):
+                        anthropic_messages.append(
+                            MessageParam(
+                                role='user',
+                                content=[
+                                    ToolResultBlockParam(
+                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                        type='tool_result',
+                                        content=part.model_response_str(),
+                                        is_error=False,
+                                    )
+                                ],
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        if part.tool_name is None:
+                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        else:
+                            anthropic_messages.append(
+                                MessageParam(
+                                    role='user',
+                                    content=[
+                                        ToolUseBlockParam(
+                                            id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                            input=part.model_response(),
+                                            name=part.tool_name,
+                                            type='tool_use',
+                                        ),
+                                    ],
+                                )
+                            )
+            elif isinstance(m, ModelResponse):
+                content: list[TextBlockParam | ToolUseBlockParam] = []
+                for item in m.parts:
+                    if isinstance(item, TextPart):
+                        content.append(TextBlockParam(text=item.content, type='text'))
+                    else:
+                        assert isinstance(item, ToolCallPart)
+                        content.append(_map_tool_call(item))
+                anthropic_messages.append(MessageParam(role='assistant', content=content))
+            else:
+                assert_never(m)
+        return system_prompt, anthropic_messages
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
+    return ToolUseBlockParam(
+        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+        type='tool_use',
+        name=t.tool_name,
+        input=t.args.args_dict,
+    )
+
+
+def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+    if isinstance(message, AnthropicMessage):
+        usage = message.usage
+    else:
+        if isinstance(message, RawMessageStartEvent):
+            usage = message.message.usage
+        elif isinstance(message, RawMessageDeltaEvent):
+            usage = message.usage
+        else:
+            # No usage information provided in:
+            # - RawMessageStopEvent
+            # - RawContentBlockStartEvent
+            # - RawContentBlockDeltaEvent
+            # - RawContentBlockStopEvent
+            usage = None
+
+    if usage is None:
+        return result.Cost()
+
+    request_tokens = getattr(usage, 'input_tokens', None)
+
+    return result.Cost(
+        # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
+        request_tokens=request_tokens,
+        response_tokens=usage.output_tokens,
+        total_tokens=(request_tokens or 0) + usage.output_tokens,
+    )
pydantic_ai/models/function.py CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast
@@ -13,15 +13,22 @@ import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload
 
 from .. import _utils, result
-from ..messages import
-
-
-
-
-
-
-
+from ..messages import (
+    ArgsJson,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
 
 
 @dataclass(init=False)
@@ -59,13 +66,13 @@ class FunctionModel(Model):
 
     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        result_tools = list(result_tools) if result_tools is not None else None
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
+            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
         )
 
     def name(self) -> str:
@@ -84,7 +91,7 @@ class AgentInfo:
     This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
     """
 
-    function_tools:
+    function_tools: list[ToolDefinition]
     """The function tools available on this agent.
 
     These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
@@ -92,8 +99,10 @@ class AgentInfo:
     """
     allow_text_result: bool
    """Whether a plain text result is allowed."""
-    result_tools: list[
+    result_tools: list[ToolDefinition]
     """The tools that can be called as the final result of the run."""
+    model_settings: ModelSettings | None
+    """The model settings passed to the run call."""
 
 
 @dataclass
@@ -112,10 +121,10 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
-FunctionDef: TypeAlias = Callable[[list[
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
-StreamFunctionDef: TypeAlias = Callable[[list[
+StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.
 
 While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
@@ -133,18 +142,25 @@ class FunctionAgentModel(AgentModel):
     stream_function: StreamFunctionDef | None
     agent_info: AgentInfo
 
-    async def request(
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        agent_info = replace(self.agent_info, model_settings=model_settings)
+
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
         if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages,
+            response = await self.function(messages, agent_info)
         else:
-            response_ = await _utils.run_in_executor(self.function, messages,
-
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_cost(chain(messages, [response]))
 
     @asynccontextmanager
-    async def request_stream(
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
@@ -212,13 +228,13 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[key] = new
 
-    def get(self, *, final: bool = False) ->
-        calls: list[
+    def get(self, *, final: bool = False) -> ModelResponse:
+        calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(
+                calls.append(ToolCallPart.from_json(c.name, c.json_args))
 
-        return
+        return ModelResponse(calls, timestamp=self._timestamp)
 
     def cost(self) -> result.Cost:
         return result.Cost()
@@ -227,32 +243,38 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def _estimate_cost(messages: Iterable[
+def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
     """Very rough guesstimate of the number of tokens associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
     """
     # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-
     request_tokens = 50
     response_tokens = 0
     for message in messages:
-        if message
-
-
-
-
-
-
-
-        elif message.role == 'model-structured-response':
-            for call in message.calls:
-                if isinstance(call.args, ArgsJson):
-                    args_str = call.args.args_json
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                    request_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolReturnPart):
+                    request_tokens += _string_cost(part.model_response_str())
+                elif isinstance(part, RetryPromptPart):
+                    request_tokens += _string_cost(part.model_response())
                else:
-
-
-
+                    assert_never(part)
+        elif isinstance(message, ModelResponse):
+            for part in message.parts:
+                if isinstance(part, TextPart):
+                    response_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolCallPart):
+                    call = part
+                    if isinstance(call.args, ArgsJson):
+                        args_str = call.args.args_json
+                    else:
+                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
+                    response_tokens += 1 + _string_cost(args_str)
+                else:
+                    assert_never(part)
         else:
             assert_never(message)
     return result.Cost(
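The reworked `FunctionDef` alias means test functions now receive the new `list[ModelMessage]` and return a `ModelResponse` built from parts. A minimal sketch against that signature; the function body and agent wiring are illustrative, not taken from the diff:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def canned_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # info.model_settings now carries any settings passed to the run call (see AgentInfo above);
    # a real test function would typically also inspect `messages` and `info.function_tools`.
    return ModelResponse([TextPart('canned response')])

agent = Agent(FunctionModel(canned_model))
result = agent.run_sync('anything')
print(result.data)
#> canned response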