pydantic-ai-slim 0.0.12-py3-none-any.whl → 0.0.14-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
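The most consequential cross-cutting change in this release is the new `pydantic_ai/settings.py` module: `ModelSettings` bundles per-request tuning knobs (`max_tokens`, `temperature`, `top_p`, `timeout`) and is now threaded through every model's `request()` and `request_stream()` methods, as the per-model diffs below show. A minimal usage sketch, assuming the agent run methods in this release accept a `model_settings` argument and that `ModelSettings` behaves as a TypedDict:

    from pydantic_ai import Agent
    from pydantic_ai.settings import ModelSettings

    agent = Agent('openai:gpt-4o')

    # The settings apply to this run only and are passed through to AgentModel.request().
    result = agent.run_sync(
        'Summarise the release.',
        model_settings=ModelSettings(temperature=0.0, max_tokens=256),
    )
    print(result.data)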
pydantic_ai/models/anthropic.py
ADDED

@@ -0,0 +1,344 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from typing import Any, Literal, Union, cast, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ArgsDict,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    EitherStreamedResponse,
+    Model,
+    cached_async_http_client,
+    check_allow_model_requests,
+)
+
+try:
+    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic.types import (
+        Message as AnthropicMessage,
+        MessageParam,
+        RawMessageDeltaEvent,
+        RawMessageStartEvent,
+        RawMessageStreamEvent,
+        TextBlock,
+        TextBlockParam,
+        ToolChoiceParam,
+        ToolParam,
+        ToolResultBlockParam,
+        ToolUseBlock,
+        ToolUseBlockParam,
+    )
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `anthropic` to use the Anthropic model, '
+        "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
+    ) from _import_error
+
+LatestAnthropicModelNames = Literal[
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+]
+"""Latest named Anthropic models."""
+
+AnthropicModelName = Union[str, LatestAnthropicModelNames]
+"""Possible Anthropic model names.
+
+Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+"""
+
+
+@dataclass(init=False)
+class AnthropicModel(Model):
+    """A model that uses the Anthropic API.
+
+    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+
+    !!! note
+        The `AnthropicModel` class does not yet support streaming responses.
+        We anticipate adding support for streaming responses in a near-term future release.
+    """
+
+    model_name: AnthropicModelName
+    client: AsyncAnthropic = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: AnthropicModelName,
+        *,
+        api_key: str | None = None,
+        anthropic_client: AsyncAnthropic | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Anthropic model.
+
+        Args:
+            model_name: The name of the Anthropic model to use. List of model names available
+                [here](https://docs.anthropic.com/en/docs/about-claude/models).
+            api_key: The API key to use for authentication; if not provided, the `ANTHROPIC_API_KEY` environment variable
+                will be used if available.
+            anthropic_client: An existing
+                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
+                client to use; if provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if anthropic_client is not None:
+            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
+            self.client = anthropic_client
+        elif http_client is not None:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
+        else:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return AnthropicAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return self.model_name
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
+
+
+@dataclass
+class AnthropicAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Anthropic models."""
+
+    client: AsyncAnthropic
+    model_name: str
+    allow_text_result: bool
+    tools: list[ToolParam]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._messages_create(messages, False, model_settings)
+        return self._process_response(response), _map_usage(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._messages_create(messages, True, model_settings)
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+    ) -> AsyncStream[RawMessageStreamEvent]:
+        pass
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> AnthropicMessage:
+        pass
+
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
+        # standalone function to make it easier to override
+        if not self.tools:
+            tool_choice: ToolChoiceParam | None = None
+        elif not self.allow_text_result:
+            tool_choice = {'type': 'any'}
+        else:
+            tool_choice = {'type': 'auto'}
+
+        system_prompt, anthropic_messages = self._map_message(messages)
+
+        model_settings = model_settings or {}
+
+        return await self.client.messages.create(
+            max_tokens=model_settings.get('max_tokens', 1024),
+            system=system_prompt or NOT_GIVEN,
+            messages=anthropic_messages,
+            model=self.model_name,
+            tools=self.tools or NOT_GIVEN,
+            tool_choice=tool_choice or NOT_GIVEN,
+            stream=stream,
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
+        )
+
+    @staticmethod
+    def _process_response(response: AnthropicMessage) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        items: list[ModelResponsePart] = []
+        for item in response.content:
+            if isinstance(item, TextBlock):
+                items.append(TextPart(item.text))
+            else:
+                assert isinstance(item, ToolUseBlock), 'unexpected item type'
+                items.append(
+                    ToolCallPart.from_raw_args(
+                        item.name,
+                        cast(dict[str, Any], item.input),
+                        item.id,
+                    )
+                )
+
+        return ModelResponse(items)
+
+    @staticmethod
+    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+        """TODO: Process a streamed response, and prepare a streaming response to return."""
+        # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
+        # Streamed responses will be supported in a future release.
+
+        raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
+
+        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
+        # RawMessageStartEvent
+        # RawMessageDeltaEvent
+        # RawMessageStopEvent
+        # RawContentBlockStartEvent
+        # RawContentBlockDeltaEvent
+        # RawContentBlockStopEvent
+        #
+        # We might refactor streaming internally before we implement this...

+    @staticmethod
+    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+        """Just maps a `pydantic_ai.Message` to an `anthropic.types.MessageParam`."""
+        system_prompt: str = ''
+        anthropic_messages: list[MessageParam] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        anthropic_messages.append(MessageParam(role='user', content=part.content))
+                    elif isinstance(part, ToolReturnPart):
+                        anthropic_messages.append(
+                            MessageParam(
+                                role='user',
+                                content=[
+                                    ToolResultBlockParam(
+                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                        type='tool_result',
+                                        content=part.model_response_str(),
+                                        is_error=False,
+                                    )
+                                ],
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        if part.tool_name is None:
+                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        else:
+                            anthropic_messages.append(
+                                MessageParam(
+                                    role='user',
+                                    content=[
+                                        ToolResultBlockParam(
+                                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                            type='tool_result',
+                                            content=part.model_response(),
+                                            is_error=True,
+                                        ),
+                                    ],
+                                )
+                            )
+            elif isinstance(m, ModelResponse):
+                content: list[TextBlockParam | ToolUseBlockParam] = []
+                for item in m.parts:
+                    if isinstance(item, TextPart):
+                        content.append(TextBlockParam(text=item.content, type='text'))
+                    else:
+                        assert isinstance(item, ToolCallPart)
+                        content.append(_map_tool_call(item))
+                anthropic_messages.append(MessageParam(role='assistant', content=content))
+            else:
+                assert_never(m)
+        return system_prompt, anthropic_messages
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
+    return ToolUseBlockParam(
+        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+        type='tool_use',
+        name=t.tool_name,
+        input=t.args_as_dict(),
+    )
+
+
+def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
+    if isinstance(message, AnthropicMessage):
+        usage = message.usage
+    else:
+        if isinstance(message, RawMessageStartEvent):
+            usage = message.message.usage
+        elif isinstance(message, RawMessageDeltaEvent):
+            usage = message.usage
+        else:
+            # No usage information provided in:
+            # - RawMessageStopEvent
+            # - RawContentBlockStartEvent
+            # - RawContentBlockDeltaEvent
+            # - RawContentBlockStopEvent
+            usage = None
+
+    if usage is None:
+        return result.Usage()
+
+    request_tokens = getattr(usage, 'input_tokens', None)
+
+    return result.Usage(
+        # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
+        request_tokens=request_tokens,
+        response_tokens=usage.output_tokens,
+        total_tokens=(request_tokens or 0) + usage.output_tokens,
+    )
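The new module wires Claude into the existing agent machinery: tool definitions become Anthropic `ToolParam` dicts, response blocks map back to `TextPart`/`ToolCallPart`, and token counts are read into `result.Usage`. A minimal non-streamed usage sketch (streaming raises `RuntimeError` at this version); it assumes `ANTHROPIC_API_KEY` is set in the environment:

    from pydantic_ai import Agent
    from pydantic_ai.models.anthropic import AnthropicModel

    # api_key defaults to the ANTHROPIC_API_KEY environment variable (see __init__ above)
    model = AnthropicModel('claude-3-5-sonnet-latest')
    agent = Agent(model)

    result = agent.run_sync('What is the capital of France?')
    print(result.data)
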
pydantic_ai/models/function.py
CHANGED
@@ -4,16 +4,27 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast

-import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload

 from .. import _utils, result
-from ..messages import …
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse

@@ -59,7 +70,7 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
+            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
         )

     def name(self) -> str:

@@ -88,6 +99,8 @@ class AgentInfo:
     """Whether a plain text result is allowed."""
     result_tools: list[ToolDefinition]
     """The tools that can be called as the final result of the run."""
+    model_settings: ModelSettings | None
+    """The model settings passed to the run call."""


 @dataclass

@@ -106,10 +119,10 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""

-FunctionDef: TypeAlias = Callable[[list[…
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""

-StreamFunctionDef: TypeAlias = Callable[[list[…
+StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.

 While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should

@@ -127,18 +140,25 @@ class FunctionAgentModel(AgentModel):
     stream_function: StreamFunctionDef | None
     agent_info: AgentInfo

-    async def request(…
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        agent_info = replace(self.agent_info, model_settings=model_settings)
+
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
         if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, …
+            response = await self.function(messages, agent_info)
         else:
-            response_ = await _utils.run_in_executor(self.function, messages, …
-…
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
         # TODO is `messages` right here? Should it just be new messages?
-        return response, …
+        return response, _estimate_usage(chain(messages, [response]))

     @asynccontextmanager
-    async def request_stream(…
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'

@@ -176,8 +196,8 @@ class FunctionStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def …
-        return result.…
+    def usage(self) -> result.Usage:
+        return result.Usage()

     def timestamp(self) -> datetime:
         return self._timestamp

@@ -206,53 +226,55 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
             else:
                 self._delta_tool_calls[key] = new

-    def get(self, *, final: bool = False) -> …
-        calls: list[…
+    def get(self, *, final: bool = False) -> ModelResponse:
+        calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(…
+                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))

-        return …
+        return ModelResponse(calls, timestamp=self._timestamp)

-    def …
-        return …
+    def usage(self) -> result.Usage:
+        return _estimate_usage([self.get()])

     def timestamp(self) -> datetime:
         return self._timestamp


-def …
-    """Very rough guesstimate of the …
+def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+    """Very rough guesstimate of the token usage associated with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
     """
     # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-…
     request_tokens = 50
     response_tokens = 0
     for message in messages:
-        if message…
-…
-        elif message.role == 'model-structured-response':
-            for call in message.calls:
-                if isinstance(call.args, ArgsJson):
-                    args_str = call.args.args_json
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                    request_tokens += _estimate_string_usage(part.content)
+                elif isinstance(part, ToolReturnPart):
+                    request_tokens += _estimate_string_usage(part.model_response_str())
+                elif isinstance(part, RetryPromptPart):
+                    request_tokens += _estimate_string_usage(part.model_response())
                 else:
-…
+                    assert_never(part)
+        elif isinstance(message, ModelResponse):
+            for part in message.parts:
+                if isinstance(part, TextPart):
+                    response_tokens += _estimate_string_usage(part.content)
+                elif isinstance(part, ToolCallPart):
+                    call = part
+                    response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
+                else:
+                    assert_never(part)
         else:
             assert_never(message)
-    return result.…
+    return result.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )


-def …
+def _estimate_string_usage(content: str) -> int:
     return len(re.split(r'[\s",.:]+', content))