pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/models/function.py
CHANGED
@@ -7,16 +7,17 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union, cast
+from typing import Callable, Union

 from typing_extensions import TypeAlias, assert_never, overload

-from .. import _utils, result
+from .. import _utils, usage
+from .._utils import PeekableAsyncStream
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
-    ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -26,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, Model, StreamStructuredResponse, StreamTextResponse
+from . import AgentModel, Model, StreamedResponse


 @dataclass(init=False)
@@ -70,16 +71,15 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
+            self.function,
+            self.stream_function,
+            AgentInfo(function_tools, allow_text_result, result_tools, None),
         )

     def name(self) -> str:
-        labels = []
-        if self.function is not None:
-            labels.append(self.function.__name__)
-        if self.stream_function is not None:
-            labels.append(f'stream-{self.stream_function.__name__}')
-        return f'function:{",".join(labels)}'
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        return f'function:{function_name}:{stream_function_name}'


 @dataclass(frozen=True)
@@ -142,106 +142,76 @@ class FunctionAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)

         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+        model_name = f'function:{self.function.__name__}'
+
         if inspect.iscoroutinefunction(self.function):
             response = await self.function(messages, agent_info)
         else:
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
+        response.model_name = model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        response_stream = self.stream_function(messages, self.agent_info)
-        try:
-            first = await response_stream.__anext__()
-        except StopAsyncIteration as e:
-            raise ValueError('Stream function must return at least one item') from e
-
-        if isinstance(first, str):
-            text_stream = cast(AsyncIterator[str], response_stream)
-            yield FunctionStreamTextResponse(first, text_stream)
-        else:
-            structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
-            yield FunctionStreamStructuredResponse(first, structured_stream)
-
-
-@dataclass
-class FunctionStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    _next: str | None
-    _iter: AsyncIterator[str]
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            self._buffer.append(self._next)
-            self._next = None
-        else:
-            self._buffer.append(await self._iter.__anext__())
+        model_name = f'function:{self.stream_function.__name__}'

-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
+        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))

-    def usage(self) -> result.Usage:
-        return _estimate_usage([])
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')

-    def timestamp(self) -> datetime:
-        return self._timestamp
+        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)


 @dataclass
-class FunctionStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+class FunctionStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""

-    _next: DeltaToolCalls | None
-    _iter: AsyncIterator[DeltaToolCalls]
-    _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
+    _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)

-    async def __anext__(self) -> None:
-        if self._next is not None:
-            tool_call = self._next
-            self._next = None
-        else:
-            tool_call = await self._iter.__anext__()
+    def __post_init__(self):
+        self._usage += _estimate_usage([])

-        for key, new in tool_call.items():
-            if current := self._delta_tool_calls.get(key):
-                current.name = current.name or new.name
-                current.json_args = (current.json_args or '') + (new.json_args or '')
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for item in self._iter:
+            if isinstance(item, str):
+                response_tokens = _estimate_string_tokens(item)
+                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
             else:
-                self._delta_tool_calls[key] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        calls: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_json(c.name, c.json_args))
-
-        return ModelResponse(calls, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return _estimate_usage([])
+                delta_tool_calls = item
+                for dtc_index, delta_tool_call in delta_tool_calls.items():
+                    if delta_tool_call.json_args:
+                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
+                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc_index,
+                        tool_name=delta_tool_call.name,
+                        args=delta_tool_call.json_args,
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event

     def timestamp(self) -> datetime:
         return self._timestamp


-def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     """Very rough guesstimate of the token usage associated with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
@@ -253,28 +223,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _string_usage(part.content)
+                    request_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _string_usage(part.model_response_str())
+                    request_tokens += _estimate_string_tokens(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _string_usage(part.model_response())
+                    request_tokens += _estimate_string_tokens(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _string_usage(part.content)
+                    response_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    response_tokens += 1 + _string_usage(call.args_as_json_str())
+                    response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
            assert_never(message)
-    return result.Usage(
+    return usage.Usage(
        request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
    )


-def _string_usage(text: str) -> int:
-    return len(re.split(r'[\s",.:]+', text))
+def _estimate_string_tokens(content: str) -> int:
+    if not content:
+        return 0
+    return len(re.split(r'[\s",.:]+', content.strip()))
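Taken together, the function.py changes replace the old FunctionStreamTextResponse / FunctionStreamStructuredResponse pair with a single FunctionStreamedResponse that routes both text deltas and tool-call deltas through the new parts manager as ModelResponseStreamEvent objects. A minimal sketch of driving the new streaming path (assuming the 0.0.20 public API; the stream function and agent wiring below are illustrative, not code from this diff):

from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, FunctionModel


async def words(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
    # Hypothetical stream function: yielding str produces text deltas;
    # yielding DeltaToolCalls mappings would produce tool-call events instead.
    for chunk in ('hello ', 'world'):
        yield chunk


agent = Agent(FunctionModel(stream_function=words))


async def main() -> None:
    async with agent.run_stream('say hi') as result:
        async for text in result.stream_text():
            print(text)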
pydantic_ai/models/gemini.py
CHANGED
@@ -2,24 +2,25 @@ from __future__ import annotations as _annotations

 import os
 import re
-from collections.abc import AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
+from uuid import uuid4

 import pydantic
-import pydantic_core
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never

-from .. import UnexpectedModelBehavior, _utils, exceptions, result
+from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -31,10 +32,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
@@ -100,6 +99,7 @@ class GeminiModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -152,7 +152,6 @@ class GeminiAgentModel(AgentModel):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ):
-        check_allow_model_requests()
         tools = [_function_from_abstract_tool(t) for t in function_tools]
         if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
@@ -171,7 +170,7 @@ class GeminiAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
@@ -179,7 +178,7 @@ class GeminiAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)

@@ -230,15 +229,13 @@ class GeminiAgentModel(AgentModel):
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r

-    @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelResponse:
+    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts)
+        return _process_response_from_parts(parts, model_name=self.model_name)

-    @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
+    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -259,11 +256,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-
-        if _extract_response_parts(start_response).is_left():
-            return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
-        else:
-            return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)

     @classmethod
     def _message_to_gemini_content(
@@ -302,86 +295,69 @@ class GeminiAgentModel(AgentModel):


 @dataclass
-class GeminiStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for the Gemini model."""
-
-    _json_content: bytearray
-    _stream: AsyncIterator[bytes]
-    _position: int = 0
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._json_content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        if final:
-            all_items = pydantic_core.from_json(self._json_content)
-            new_items = all_items[self._position :]
-            self._position = len(all_items)
-            new_responses = _gemini_streamed_response_ta.validate_python(new_items)
-        else:
-            all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
-            new_items = all_items[self._position : -1]
-            self._position = len(all_items) - 1
-            new_responses = _gemini_streamed_response_ta.validate_python(
-                new_items, experimental_allow_partial='trailing-strings'
-            )
-        for r in new_responses:
-            self._usage += _metadata_as_usage(r)
-            parts = r['candidates'][0]['content']['parts']
-            if _all_text_parts(parts):
-                for part in parts:
-                    yield part['text']
-            else:
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be text'
-                )
-
-    def usage(self) -> result.Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GeminiStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for the Gemini model."""
+class GeminiStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for the Gemini model."""

     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._content.extend(chunk)

-    def get(self, *, final: bool = False) -> ModelResponse:
-
-
-
-
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for gemini_response in self._get_gemini_responses():
+            candidate = gemini_response['candidates'][0]
+            gemini_part: _GeminiPartUnion
+            for gemini_part in candidate['content']['parts']:
+                if 'text' in gemini_part:
+                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
+                    # amongst the tool call deltas
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
+
+                elif 'function_call' in gemini_part:
+                    # Here, we assume all function_call parts are complete and don't have deltas.
+                    # We do this by assigning a unique randomly generated "vendor_part_id".
+                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
+                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=uuid4(),
+                        tool_name=gemini_part['function_call']['name'],
+                        args=gemini_part['function_call']['args'],
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+                else:
+                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
+
+    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
+        # This method exists to ensure we only yield completed items, so we don't need to worry about
+        # partial gemini responses, which would make everything more complicated
+
+        gemini_responses: list[_GeminiResponse] = []
+        current_gemini_response_index = 0
+        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
+        # But changing that would make things a lot more complicated.
+        async for chunk in self._stream:
+            self._content.extend(chunk)
+
+            gemini_responses = _gemini_streamed_response_ta.validate_json(
+                self._content,
+                experimental_allow_partial='trailing-strings',
+            )

-
-
-
-
-
-
-
-
-
+            # The idea: yield only up to the latest response, which might still be partial.
+            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
+            # allow_partial API to determine if the last item in the list is complete.
+            responses_to_yield = gemini_responses[:-1]
+            for r in responses_to_yield[current_gemini_response_index:]:
+                current_gemini_response_index += 1
+                self._usage += _metadata_as_usage(r)
+                yield r
+
+        # Now yield the final response, which should be complete
+        if gemini_responses:
+            r = gemini_responses[-1]
             self._usage += _metadata_as_usage(r)
-            candidate = r['candidates'][0]
-            combined_parts.extend(candidate['content']['parts'])
-            return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return self._usage
+            yield r

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -454,18 +430,25 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))


-def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+def _process_response_from_parts(
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
-            items.append(TextPart(part['text']))
+            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+            items.append(
+                ToolCallPart.from_raw_args(
+                    tool_name=part['function_call']['name'],
+                    args=part['function_call']['args'],
+                )
+            )
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
+    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())


 class _GeminiFunctionCall(TypedDict):
@@ -575,35 +558,6 @@ class _GeminiResponse(TypedDict):
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]


-# TODO: Delete the next three functions once we've reworked streams to be more flexible
-def _extract_response_parts(
-    response: _GeminiResponse,
-) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
-    """Extract the parts of the response from the Gemini API.
-
-    Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
-    """
-    if len(response['candidates']) != 1:
-        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
-    parts = response['candidates'][0]['content']['parts']
-    if _all_function_call_parts(parts):
-        return _utils.Either(left=parts)
-    elif _all_text_parts(parts):
-        return _utils.Either(right=parts)
-    else:
-        raise exceptions.UnexpectedModelBehavior(
-            f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
-        )
-
-
-def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
-    return all('function_call' in part for part in parts)
-
-
-def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
-    return all('text' in part for part in parts)
-
-
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

@@ -630,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]


-def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Usage()
+        return usage.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Usage(
+    return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),