pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +82 -74
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +218 -9
- pydantic_ai/models/__init__.py +31 -72
- pydantic_ai/models/anthropic.py +21 -21
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +76 -122
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/models/__init__.py
CHANGED
|
@@ -7,20 +7,22 @@ specific LLM being used.
|
|
|
7
7
|
from __future__ import annotations as _annotations
|
|
8
8
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
|
-
from collections.abc import AsyncIterator,
|
|
10
|
+
from collections.abc import AsyncIterator, Iterator
|
|
11
11
|
from contextlib import asynccontextmanager, contextmanager
|
|
12
|
+
from dataclasses import dataclass, field
|
|
12
13
|
from datetime import datetime
|
|
13
14
|
from functools import cache
|
|
14
|
-
from typing import TYPE_CHECKING, Literal
|
|
15
|
+
from typing import TYPE_CHECKING, Literal
|
|
15
16
|
|
|
16
17
|
import httpx
|
|
17
18
|
|
|
19
|
+
from .._parts_manager import ModelResponsePartsManager
|
|
18
20
|
from ..exceptions import UserError
|
|
19
|
-
from ..messages import ModelMessage, ModelResponse
|
|
21
|
+
from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
|
|
20
22
|
from ..settings import ModelSettings
|
|
23
|
+
from ..usage import Usage
|
|
21
24
|
|
|
22
25
|
if TYPE_CHECKING:
|
|
23
|
-
from ..result import Usage
|
|
24
26
|
from ..tools import ToolDefinition
|
|
25
27
|
|
|
26
28
|
|
|
@@ -70,6 +72,7 @@ KnownModelName = Literal[
|
|
|
70
72
|
'ollama:mistral-nemo',
|
|
71
73
|
'ollama:mixtral',
|
|
72
74
|
'ollama:phi3',
|
|
75
|
+
'ollama:phi4',
|
|
73
76
|
'ollama:qwq',
|
|
74
77
|
'ollama:qwen',
|
|
75
78
|
'ollama:qwen2',
|
|
@@ -129,88 +132,47 @@ class AgentModel(ABC):
|
|
|
129
132
|
@asynccontextmanager
|
|
130
133
|
async def request_stream(
|
|
131
134
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
132
|
-
) -> AsyncIterator[
|
|
135
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
133
136
|
"""Make a request to the model and return a streaming response."""
|
|
137
|
+
# This method is not required, but you need to implement it if you want to support streamed responses
|
|
134
138
|
raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
|
|
135
139
|
# yield is required to make this a generator for type checking
|
|
136
140
|
# noinspection PyUnreachableCode
|
|
137
141
|
yield # pragma: no cover
|
|
138
142
|
|
|
139
143
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def __aiter__(self) -> AsyncIterator[None]:
|
|
144
|
-
"""Stream the response as an async iterable, building up the text as it goes.
|
|
145
|
-
|
|
146
|
-
This is an async iterator that yields `None` to avoid doing the work of validating the input and
|
|
147
|
-
extracting the text field when it will often be thrown away.
|
|
148
|
-
"""
|
|
149
|
-
return self
|
|
150
|
-
|
|
151
|
-
@abstractmethod
|
|
152
|
-
async def __anext__(self) -> None:
|
|
153
|
-
"""Process the next chunk of the response, see above for why this returns `None`."""
|
|
154
|
-
raise NotImplementedError()
|
|
155
|
-
|
|
156
|
-
@abstractmethod
|
|
157
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
158
|
-
"""Returns an iterable of text since the last call to `get()` — e.g. the text delta.
|
|
159
|
-
|
|
160
|
-
Args:
|
|
161
|
-
final: If True, this is the final call, after iteration is complete, the response should be fully validated
|
|
162
|
-
and all text extracted.
|
|
163
|
-
"""
|
|
164
|
-
raise NotImplementedError()
|
|
165
|
-
|
|
166
|
-
@abstractmethod
|
|
167
|
-
def usage(self) -> Usage:
|
|
168
|
-
"""Return the usage of the request.
|
|
169
|
-
|
|
170
|
-
NOTE: this won't return the full usage until the stream is finished.
|
|
171
|
-
"""
|
|
172
|
-
raise NotImplementedError()
|
|
173
|
-
|
|
174
|
-
@abstractmethod
|
|
175
|
-
def timestamp(self) -> datetime:
|
|
176
|
-
"""Get the timestamp of the response."""
|
|
177
|
-
raise NotImplementedError()
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
class StreamStructuredResponse(ABC):
|
|
144
|
+
@dataclass
|
|
145
|
+
class StreamedResponse(ABC):
|
|
181
146
|
"""Streamed response from an LLM when calling a tool."""
|
|
182
147
|
|
|
183
|
-
|
|
184
|
-
|
|
148
|
+
_usage: Usage = field(default_factory=Usage, init=False)
|
|
149
|
+
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
150
|
+
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
|
|
185
151
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
152
|
+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
153
|
+
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
|
|
154
|
+
if self._event_iterator is None:
|
|
155
|
+
self._event_iterator = self._get_event_iterator()
|
|
156
|
+
return self._event_iterator
|
|
190
157
|
|
|
191
158
|
@abstractmethod
|
|
192
|
-
async def
|
|
193
|
-
"""
|
|
194
|
-
raise NotImplementedError()
|
|
195
|
-
|
|
196
|
-
@abstractmethod
|
|
197
|
-
def get(self, *, final: bool = False) -> ModelResponse:
|
|
198
|
-
"""Get the `ModelResponse` at this point.
|
|
199
|
-
|
|
200
|
-
The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
|
|
159
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
160
|
+
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
201
161
|
|
|
202
|
-
|
|
203
|
-
|
|
162
|
+
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
163
|
+
pydantic_ai-format events.
|
|
204
164
|
"""
|
|
205
165
|
raise NotImplementedError()
|
|
166
|
+
# noinspection PyUnreachableCode
|
|
167
|
+
yield
|
|
206
168
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
169
|
+
def get(self) -> ModelResponse:
|
|
170
|
+
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
|
|
171
|
+
return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
|
|
210
172
|
|
|
211
|
-
|
|
212
|
-
"""
|
|
213
|
-
|
|
173
|
+
def usage(self) -> Usage:
|
|
174
|
+
"""Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
|
|
175
|
+
return self._usage
|
|
214
176
|
|
|
215
177
|
@abstractmethod
|
|
216
178
|
def timestamp(self) -> datetime:
|
|
@@ -218,9 +180,6 @@ class StreamStructuredResponse(ABC):
|
|
|
218
180
|
raise NotImplementedError()
|
|
219
181
|
|
|
220
182
|
|
|
221
|
-
EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
|
|
222
|
-
|
|
223
|
-
|
|
224
183
|
ALLOW_MODEL_REQUESTS = True
|
|
225
184
|
"""Whether to allow requests to models.
|
|
226
185
|
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, cast, overload
|
|
|
8
8
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
9
9
|
from typing_extensions import assert_never
|
|
10
10
|
|
|
11
|
-
from .. import
|
|
11
|
+
from .. import usage
|
|
12
12
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
13
13
|
from ..messages import (
|
|
14
14
|
ArgsDict,
|
|
@@ -27,8 +27,8 @@ from ..settings import ModelSettings
|
|
|
27
27
|
from ..tools import ToolDefinition
|
|
28
28
|
from . import (
|
|
29
29
|
AgentModel,
|
|
30
|
-
EitherStreamedResponse,
|
|
31
30
|
Model,
|
|
31
|
+
StreamedResponse,
|
|
32
32
|
cached_async_http_client,
|
|
33
33
|
check_allow_model_requests,
|
|
34
34
|
)
|
|
@@ -158,14 +158,14 @@ class AnthropicAgentModel(AgentModel):
|
|
|
158
158
|
|
|
159
159
|
async def request(
|
|
160
160
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
161
|
-
) -> tuple[ModelResponse,
|
|
161
|
+
) -> tuple[ModelResponse, usage.Usage]:
|
|
162
162
|
response = await self._messages_create(messages, False, model_settings)
|
|
163
163
|
return self._process_response(response), _map_usage(response)
|
|
164
164
|
|
|
165
165
|
@asynccontextmanager
|
|
166
166
|
async def request_stream(
|
|
167
167
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
168
|
-
) -> AsyncIterator[
|
|
168
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
169
169
|
response = await self._messages_create(messages, True, model_settings)
|
|
170
170
|
async with response:
|
|
171
171
|
yield await self._process_streamed_response(response)
|
|
@@ -216,28 +216,28 @@ class AnthropicAgentModel(AgentModel):
|
|
|
216
216
|
items: list[ModelResponsePart] = []
|
|
217
217
|
for item in response.content:
|
|
218
218
|
if isinstance(item, TextBlock):
|
|
219
|
-
items.append(TextPart(item.text))
|
|
219
|
+
items.append(TextPart(content=item.text))
|
|
220
220
|
else:
|
|
221
221
|
assert isinstance(item, ToolUseBlock), 'unexpected item type'
|
|
222
222
|
items.append(
|
|
223
223
|
ToolCallPart.from_raw_args(
|
|
224
|
-
item.name,
|
|
225
|
-
cast(dict[str, Any], item.input),
|
|
226
|
-
item.id,
|
|
224
|
+
tool_name=item.name,
|
|
225
|
+
args=cast(dict[str, Any], item.input),
|
|
226
|
+
tool_call_id=item.id,
|
|
227
227
|
)
|
|
228
228
|
)
|
|
229
229
|
|
|
230
230
|
return ModelResponse(items)
|
|
231
231
|
|
|
232
232
|
@staticmethod
|
|
233
|
-
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) ->
|
|
233
|
+
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
|
|
234
234
|
"""TODO: Process a streamed response, and prepare a streaming response to return."""
|
|
235
235
|
# We don't yet support streamed responses from Anthropic, so we raise an error here for now.
|
|
236
236
|
# Streamed responses will be supported in a future release.
|
|
237
237
|
|
|
238
238
|
raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
|
|
239
239
|
|
|
240
|
-
# Should be returning some sort of AnthropicStreamTextResponse or
|
|
240
|
+
# Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
|
|
241
241
|
# depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
|
|
242
242
|
# RawMessageStartEvent
|
|
243
243
|
# RawMessageDeltaEvent
|
|
@@ -315,30 +315,30 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
|
|
|
315
315
|
)
|
|
316
316
|
|
|
317
317
|
|
|
318
|
-
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) ->
|
|
318
|
+
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
|
|
319
319
|
if isinstance(message, AnthropicMessage):
|
|
320
|
-
|
|
320
|
+
response_usage = message.usage
|
|
321
321
|
else:
|
|
322
322
|
if isinstance(message, RawMessageStartEvent):
|
|
323
|
-
|
|
323
|
+
response_usage = message.message.usage
|
|
324
324
|
elif isinstance(message, RawMessageDeltaEvent):
|
|
325
|
-
|
|
325
|
+
response_usage = message.usage
|
|
326
326
|
else:
|
|
327
327
|
# No usage information provided in:
|
|
328
328
|
# - RawMessageStopEvent
|
|
329
329
|
# - RawContentBlockStartEvent
|
|
330
330
|
# - RawContentBlockDeltaEvent
|
|
331
331
|
# - RawContentBlockStopEvent
|
|
332
|
-
|
|
332
|
+
response_usage = None
|
|
333
333
|
|
|
334
|
-
if
|
|
335
|
-
return
|
|
334
|
+
if response_usage is None:
|
|
335
|
+
return usage.Usage()
|
|
336
336
|
|
|
337
|
-
request_tokens = getattr(
|
|
337
|
+
request_tokens = getattr(response_usage, 'input_tokens', None)
|
|
338
338
|
|
|
339
|
-
return
|
|
339
|
+
return usage.Usage(
|
|
340
340
|
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
|
|
341
341
|
request_tokens=request_tokens,
|
|
342
|
-
response_tokens=
|
|
343
|
-
total_tokens=(request_tokens or 0) +
|
|
342
|
+
response_tokens=response_usage.output_tokens,
|
|
343
|
+
total_tokens=(request_tokens or 0) + response_usage.output_tokens,
|
|
344
344
|
)
|
pydantic_ai/models/function.py
CHANGED
|
@@ -7,16 +7,17 @@ from contextlib import asynccontextmanager
|
|
|
7
7
|
from dataclasses import dataclass, field, replace
|
|
8
8
|
from datetime import datetime
|
|
9
9
|
from itertools import chain
|
|
10
|
-
from typing import Callable, Union
|
|
10
|
+
from typing import Callable, Union
|
|
11
11
|
|
|
12
12
|
from typing_extensions import TypeAlias, assert_never, overload
|
|
13
13
|
|
|
14
|
-
from .. import _utils,
|
|
14
|
+
from .. import _utils, usage
|
|
15
|
+
from .._utils import PeekableAsyncStream
|
|
15
16
|
from ..messages import (
|
|
16
17
|
ModelMessage,
|
|
17
18
|
ModelRequest,
|
|
18
19
|
ModelResponse,
|
|
19
|
-
|
|
20
|
+
ModelResponseStreamEvent,
|
|
20
21
|
RetryPromptPart,
|
|
21
22
|
SystemPromptPart,
|
|
22
23
|
TextPart,
|
|
@@ -26,7 +27,7 @@ from ..messages import (
|
|
|
26
27
|
)
|
|
27
28
|
from ..settings import ModelSettings
|
|
28
29
|
from ..tools import ToolDefinition
|
|
29
|
-
from . import AgentModel,
|
|
30
|
+
from . import AgentModel, Model, StreamedResponse
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
@dataclass(init=False)
|
|
@@ -142,7 +143,7 @@ class FunctionAgentModel(AgentModel):
|
|
|
142
143
|
|
|
143
144
|
async def request(
|
|
144
145
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
145
|
-
) -> tuple[ModelResponse,
|
|
146
|
+
) -> tuple[ModelResponse, usage.Usage]:
|
|
146
147
|
agent_info = replace(self.agent_info, model_settings=model_settings)
|
|
147
148
|
|
|
148
149
|
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
|
|
@@ -158,90 +159,55 @@ class FunctionAgentModel(AgentModel):
|
|
|
158
159
|
@asynccontextmanager
|
|
159
160
|
async def request_stream(
|
|
160
161
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
161
|
-
) -> AsyncIterator[
|
|
162
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
162
163
|
assert (
|
|
163
164
|
self.stream_function is not None
|
|
164
165
|
), 'FunctionModel must receive a `stream_function` to support streamed requests'
|
|
165
|
-
response_stream = self.stream_function(messages, self.agent_info)
|
|
166
|
-
try:
|
|
167
|
-
first = await response_stream.__anext__()
|
|
168
|
-
except StopAsyncIteration as e:
|
|
169
|
-
raise ValueError('Stream function must return at least one item') from e
|
|
170
|
-
|
|
171
|
-
if isinstance(first, str):
|
|
172
|
-
text_stream = cast(AsyncIterator[str], response_stream)
|
|
173
|
-
yield FunctionStreamTextResponse(first, text_stream)
|
|
174
|
-
else:
|
|
175
|
-
structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
|
|
176
|
-
yield FunctionStreamStructuredResponse(first, structured_stream)
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
@dataclass
|
|
180
|
-
class FunctionStreamTextResponse(StreamTextResponse):
|
|
181
|
-
"""Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
|
|
182
|
-
|
|
183
|
-
_next: str | None
|
|
184
|
-
_iter: AsyncIterator[str]
|
|
185
|
-
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
|
|
186
|
-
_buffer: list[str] = field(default_factory=list, init=False)
|
|
187
|
-
|
|
188
|
-
async def __anext__(self) -> None:
|
|
189
|
-
if self._next is not None:
|
|
190
|
-
self._buffer.append(self._next)
|
|
191
|
-
self._next = None
|
|
192
|
-
else:
|
|
193
|
-
self._buffer.append(await self._iter.__anext__())
|
|
194
|
-
|
|
195
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
196
|
-
yield from self._buffer
|
|
197
|
-
self._buffer.clear()
|
|
166
|
+
response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
|
|
198
167
|
|
|
199
|
-
|
|
200
|
-
|
|
168
|
+
first = await response_stream.peek()
|
|
169
|
+
if isinstance(first, _utils.Unset):
|
|
170
|
+
raise ValueError('Stream function must return at least one item')
|
|
201
171
|
|
|
202
|
-
|
|
203
|
-
return self._timestamp
|
|
172
|
+
yield FunctionStreamedResponse(response_stream)
|
|
204
173
|
|
|
205
174
|
|
|
206
175
|
@dataclass
|
|
207
|
-
class
|
|
208
|
-
"""Implementation of `
|
|
176
|
+
class FunctionStreamedResponse(StreamedResponse):
|
|
177
|
+
"""Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
|
|
209
178
|
|
|
210
|
-
|
|
211
|
-
_iter: AsyncIterator[DeltaToolCalls]
|
|
212
|
-
_delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
|
|
179
|
+
_iter: AsyncIterator[str | DeltaToolCalls]
|
|
213
180
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
214
181
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
tool_call = self._next
|
|
218
|
-
self._next = None
|
|
219
|
-
else:
|
|
220
|
-
tool_call = await self._iter.__anext__()
|
|
182
|
+
def __post_init__(self):
|
|
183
|
+
self._usage += _estimate_usage([])
|
|
221
184
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
185
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
186
|
+
async for item in self._iter:
|
|
187
|
+
if isinstance(item, str):
|
|
188
|
+
response_tokens = _estimate_string_tokens(item)
|
|
189
|
+
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
|
|
190
|
+
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
|
|
226
191
|
else:
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
192
|
+
delta_tool_calls = item
|
|
193
|
+
for dtc_index, delta_tool_call in delta_tool_calls.items():
|
|
194
|
+
if delta_tool_call.json_args:
|
|
195
|
+
response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
|
|
196
|
+
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
|
|
197
|
+
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
198
|
+
vendor_part_id=dtc_index,
|
|
199
|
+
tool_name=delta_tool_call.name,
|
|
200
|
+
args=delta_tool_call.json_args,
|
|
201
|
+
tool_call_id=None,
|
|
202
|
+
)
|
|
203
|
+
if maybe_event is not None:
|
|
204
|
+
yield maybe_event
|
|
239
205
|
|
|
240
206
|
def timestamp(self) -> datetime:
|
|
241
207
|
return self._timestamp
|
|
242
208
|
|
|
243
209
|
|
|
244
|
-
def _estimate_usage(messages: Iterable[ModelMessage]) ->
|
|
210
|
+
def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
|
|
245
211
|
"""Very rough guesstimate of the token usage associated with a series of messages.
|
|
246
212
|
|
|
247
213
|
This is designed to be used solely to give plausible numbers for testing!
|
|
@@ -253,28 +219,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
|
|
|
253
219
|
if isinstance(message, ModelRequest):
|
|
254
220
|
for part in message.parts:
|
|
255
221
|
if isinstance(part, (SystemPromptPart, UserPromptPart)):
|
|
256
|
-
request_tokens +=
|
|
222
|
+
request_tokens += _estimate_string_tokens(part.content)
|
|
257
223
|
elif isinstance(part, ToolReturnPart):
|
|
258
|
-
request_tokens +=
|
|
224
|
+
request_tokens += _estimate_string_tokens(part.model_response_str())
|
|
259
225
|
elif isinstance(part, RetryPromptPart):
|
|
260
|
-
request_tokens +=
|
|
226
|
+
request_tokens += _estimate_string_tokens(part.model_response())
|
|
261
227
|
else:
|
|
262
228
|
assert_never(part)
|
|
263
229
|
elif isinstance(message, ModelResponse):
|
|
264
230
|
for part in message.parts:
|
|
265
231
|
if isinstance(part, TextPart):
|
|
266
|
-
response_tokens +=
|
|
232
|
+
response_tokens += _estimate_string_tokens(part.content)
|
|
267
233
|
elif isinstance(part, ToolCallPart):
|
|
268
234
|
call = part
|
|
269
|
-
response_tokens += 1 +
|
|
235
|
+
response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
|
|
270
236
|
else:
|
|
271
237
|
assert_never(part)
|
|
272
238
|
else:
|
|
273
239
|
assert_never(message)
|
|
274
|
-
return
|
|
240
|
+
return usage.Usage(
|
|
275
241
|
request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
|
|
276
242
|
)
|
|
277
243
|
|
|
278
244
|
|
|
279
|
-
def
|
|
280
|
-
|
|
245
|
+
def _estimate_string_tokens(content: str) -> int:
|
|
246
|
+
if not content:
|
|
247
|
+
return 0
|
|
248
|
+
return len(re.split(r'[\s",.:]+', content.strip()))
|