pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_griffe.py +23 -4
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_system_prompt.py +1 -0
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +332 -124
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +224 -9
- pydantic_ai/models/__init__.py +59 -82
- pydantic_ai/models/anthropic.py +22 -22
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +86 -125
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/models/vertexai.py +1 -1
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.17.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/models/function.py
CHANGED
@@ -7,16 +7,17 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union, cast
+from typing import Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
-from .. import _utils, result
+from .. import _utils, usage
+from .._utils import PeekableAsyncStream
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
-    ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -26,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
+from . import AgentModel, Model, StreamedResponse
 
 
 @dataclass(init=False)
@@ -142,7 +143,7 @@ class FunctionAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -158,90 +159,55 @@ class FunctionAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        response_stream = self.stream_function(messages, self.agent_info)
-        try:
-            first = await response_stream.__anext__()
-        except StopAsyncIteration as e:
-            raise ValueError('Stream function must return at least one item') from e
-
-        if isinstance(first, str):
-            text_stream = cast(AsyncIterator[str], response_stream)
-            yield FunctionStreamTextResponse(first, text_stream)
-        else:
-            structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
-            yield FunctionStreamStructuredResponse(first, structured_stream)
-
-
-@dataclass
-class FunctionStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    _next: str | None
-    _iter: AsyncIterator[str]
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            self._buffer.append(self._next)
-            self._next = None
-        else:
-            self._buffer.append(await self._iter.__anext__())
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
+        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
 
-    def usage(self) -> result.Usage:
-        return _estimate_usage([])
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
 
-    def timestamp(self) -> datetime:
-        return self._timestamp
+        yield FunctionStreamedResponse(response_stream)
 
 
 @dataclass
-class FunctionStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+class FunctionStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
-    _next: DeltaToolCalls | None
-    _iter: AsyncIterator[DeltaToolCalls]
-    _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
+    _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            tool_call = self._next
-            self._next = None
-        else:
-            tool_call = await self._iter.__anext__()
+    def __post_init__(self):
+        self._usage += _estimate_usage([])
 
-        for key, new in tool_call.items():
-            if current := self._delta_tool_calls.get(key):
-                current.name = current.name or new.name
-                current.json_args = (current.json_args or '') + (new.json_args or '')
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for item in self._iter:
+            if isinstance(item, str):
+                response_tokens = _estimate_string_tokens(item)
+                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
             else:
-                self._delta_tool_calls[key] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        calls: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))
-
-        return ModelResponse(calls, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return _estimate_usage([])
+                delta_tool_calls = item
+                for dtc_index, delta_tool_call in delta_tool_calls.items():
+                    if delta_tool_call.json_args:
+                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
+                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc_index,
+                        tool_name=delta_tool_call.name,
+                        args=delta_tool_call.json_args,
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     """Very rough guesstimate of the token usage associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
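The key move in this hunk is replacing the old try/`__anext__`/StopAsyncIteration dance with `_utils.PeekableAsyncStream`, so `request_stream` can confirm the stream function yields at least one item and still hand the unconsumed stream to `FunctionStreamedResponse`. A minimal standalone sketch of that peek-without-consuming pattern (the class and the `Unset` sentinel are illustrative stand-ins for the pydantic_ai._utils helpers, not the library's implementation):

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Generic, TypeVar

T = TypeVar('T')


class Unset:
    """Sentinel returned by `peek` when the underlying stream is exhausted."""


class PeekableAsyncStream(Generic[T]):
    """Wrap an async iterator so the next item can be inspected without consuming it."""

    def __init__(self, source: AsyncIterator[T]):
        self._source = source
        self._peeked: T | Unset = Unset()

    async def peek(self) -> T | Unset:
        # Pull one item ahead of the consumer and cache it until __anext__ is called.
        if isinstance(self._peeked, Unset):
            try:
                self._peeked = await self._source.__anext__()
            except StopAsyncIteration:
                return Unset()
        return self._peeked

    def __aiter__(self) -> PeekableAsyncStream[T]:
        return self

    async def __anext__(self) -> T:
        # Hand back the cached item first, then fall through to the source.
        if not isinstance(self._peeked, Unset):
            item, self._peeked = self._peeked, Unset()
            return item
        return await self._source.__anext__()

With a wrapper like this, `await stream.peek()` surfaces either the first item or an `Unset` sentinel for an empty stream, and iterating the wrapper afterwards still starts from that first item.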
@@ -253,28 +219,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _string_usage(part.content)
+                    request_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _string_usage(part.model_response_str())
+                    request_tokens += _estimate_string_tokens(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _string_usage(part.model_response())
+                    request_tokens += _estimate_string_tokens(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _string_usage(part.content)
+                    response_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    response_tokens += 1 + _string_usage(call.args_as_json_str())
+                    response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
             assert_never(message)
-    return result.Usage(
+    return usage.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )
 
 
-def _string_usage(content: str) -> int:
-    return len(re.split(r'[\s",.:]+', content.strip()))
+def _estimate_string_tokens(content: str) -> int:
+    if not content:
+        return 0
+    return len(re.split(r'[\s",.:]+', content.strip()))
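The new `_estimate_string_tokens` helper consolidates the counting the old code did per call site: split on runs of whitespace plus a few JSON-ish punctuation characters and count the pieces. As the docstring above says, it is only meant to produce plausible numbers for tests. A quick standalone check of the behavior (the example strings are arbitrary):

import re


def _estimate_string_tokens(content: str) -> int:
    # Same heuristic as the helper added above: empty input is zero tokens,
    # anything else is split on runs of whitespace and of " , . : characters.
    if not content:
        return 0
    return len(re.split(r'[\s",.:]+', content.strip()))


assert _estimate_string_tokens('') == 0
assert _estimate_string_tokens('hello world') == 2
# JSON punctuation falls into the split class, so a small JSON object counts
# its braces and bare words rather than every character:
assert _estimate_string_tokens('{"city": "London"}') == 4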
pydantic_ai/models/gemini.py
CHANGED
@@ -2,24 +2,25 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
+from uuid import uuid4
 
 import pydantic
-import pydantic_core
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, _utils, exceptions, result
+from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -31,10 +32,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
@@ -111,7 +110,7 @@ class GeminiModel(Model):
         )
 
     def name(self) -> str:
-        return self.model_name
+        return f'google-gla:{self.model_name}'
 
 
 class AuthProtocol(Protocol):
@@ -171,7 +170,7 @@ class GeminiAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
@@ -179,7 +178,7 @@
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
@@ -238,7 +237,7 @@
         return _process_response_from_parts(parts)
 
     @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
+    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -259,11 +258,7 @@
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        # TODO: Update this once we rework streams to be more flexible
-        if _extract_response_parts(start_response).is_left():
-            return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
-        else:
-            return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -302,86 +297,69 @@ class GeminiAgentModel(AgentModel):
 
 
 @dataclass
-class GeminiStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for the Gemini model."""
-
-    _json_content: bytearray
-    _stream: AsyncIterator[bytes]
-    _position: int = 0
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._json_content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        if final:
-            all_items = pydantic_core.from_json(self._json_content)
-            new_items = all_items[self._position :]
-            self._position = len(all_items)
-            new_responses = _gemini_streamed_response_ta.validate_python(new_items)
-        else:
-            all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
-            new_items = all_items[self._position : -1]
-            self._position = len(all_items) - 1
-            new_responses = _gemini_streamed_response_ta.validate_python(
-                new_items, experimental_allow_partial='trailing-strings'
-            )
-        for r in new_responses:
-            self._usage += _metadata_as_usage(r)
-            parts = r['candidates'][0]['content']['parts']
-            if _all_text_parts(parts):
-                for part in parts:
-                    yield part['text']
-            else:
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be text'
-                )
-
-    def usage(self) -> result.Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GeminiStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for the Gemini model."""
+class GeminiStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for the Gemini model."""
 
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        """Get the `ModelResponse` at this point.
 
-        NOTE: It's not clear how the stream of responses should be combined because Gemini
-        doesn't document this.
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for gemini_response in self._get_gemini_responses():
+            candidate = gemini_response['candidates'][0]
+            gemini_part: _GeminiPartUnion
+            for gemini_part in candidate['content']['parts']:
+                if 'text' in gemini_part:
+                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
+                    # amongst the tool call deltas
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
+
+                elif 'function_call' in gemini_part:
+                    # Here, we assume all function_call parts are complete and don't have deltas.
+                    # We do this by assigning a unique randomly generated "vendor_part_id".
+                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
+                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=uuid4(),
+                        tool_name=gemini_part['function_call']['name'],
+                        args=gemini_part['function_call']['args'],
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+                else:
+                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
+
+    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
+        # This method exists to ensure we only yield completed items, so we don't need to worry about
+        # partial gemini responses, which would make everything more complicated
+
+        gemini_responses: list[_GeminiResponse] = []
+        current_gemini_response_index = 0
+        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
+        # But changing that would make things a lot more complicated.
+        async for chunk in self._stream:
+            self._content.extend(chunk)
+
+            gemini_responses = _gemini_streamed_response_ta.validate_json(
+                self._content,
+                experimental_allow_partial='trailing-strings',
+            )
 
-        I'm therefore assuming that each part contains a complete tool call, and not trying to combine data
-        from separate parts.
-        """
-        responses = _gemini_streamed_response_ta.validate_json(
-            self._content,
-            experimental_allow_partial='off' if final else 'trailing-strings',
-        )
-        combined_parts: list[_GeminiPartUnion] = []
-        self._usage = result.Usage()
-        for r in responses:
+            # The idea: yield only up to the latest response, which might still be partial.
+            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
+            # allow_partial API to determine if the last item in the list is complete.
+            responses_to_yield = gemini_responses[:-1]
+            for r in responses_to_yield[current_gemini_response_index:]:
+                current_gemini_response_index += 1
+                self._usage += _metadata_as_usage(r)
+                yield r
+
+        # Now yield the final response, which should be complete
+        if gemini_responses:
+            r = gemini_responses[-1]
             self._usage += _metadata_as_usage(r)
-            candidate = r['candidates'][0]
-            combined_parts.extend(candidate['content']['parts'])
-        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return self._usage
+            yield r
 
     def timestamp(self) -> datetime:
         return self._timestamp
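`_get_gemini_responses` above shows the buffering strategy: append every chunk to one bytearray, re-validate the whole buffer with `experimental_allow_partial='trailing-strings'`, and only yield items that can no longer change, i.e. everything before the last parsed element. The same idea can be sketched with `pydantic_core.from_json`, which the removed 0.0.17 code used; the item shape and chunk boundaries below are made up for the demo:

import asyncio
from collections.abc import AsyncIterator
from typing import Any

import pydantic_core


async def _fake_byte_stream() -> AsyncIterator[bytes]:
    # Hypothetical chunking of a streamed JSON array; real chunk boundaries are arbitrary.
    for chunk in (b'[{"text": "hel', b'lo"}, {"text": "wor', b'ld"}]'):
        yield chunk


async def parse_streamed_json_array(stream: AsyncIterator[bytes]) -> AsyncIterator[dict[str, Any]]:
    buffer = bytearray()
    yielded = 0
    async for chunk in stream:
        buffer.extend(chunk)
        # allow_partial tolerates a truncated tail, so the growing buffer always parses.
        items = pydantic_core.from_json(bytes(buffer), allow_partial=True)
        # The last item may still be incomplete, so only yield up to (but not including) it.
        for item in items[yielded : len(items) - 1]:
            yielded += 1
            yield item
    # After the stream ends the buffer is complete JSON; yield whatever remains.
    for item in pydantic_core.from_json(bytes(buffer))[yielded:]:
        yield item


async def main() -> None:
    async for item in parse_streamed_json_array(_fake_byte_stream()):
        print(item)


if __name__ == '__main__':
    asyncio.run(main())

Holding back the final element until the stream closes is the same conservative choice the comment in the hunk explains: there is no cheap way to ask the partial parser whether the last item is already complete.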
@@ -458,9 +436,14 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
-            items.append(TextPart(part['text']))
+            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
+            items.append(
+                ToolCallPart.from_raw_args(
+                    tool_name=part['function_call']['name'],
+                    args=part['function_call']['args'],
+                )
+            )
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -575,35 +558,6 @@ class _GeminiResponse(TypedDict):
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
-# TODO: Delete the next three functions once we've reworked streams to be more flexible
-def _extract_response_parts(
-    response: _GeminiResponse,
-) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
-    """Extract the parts of the response from the Gemini API.
-
-    Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
-    """
-    if len(response['candidates']) != 1:
-        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
-    parts = response['candidates'][0]['content']['parts']
-    if _all_function_call_parts(parts):
-        return _utils.Either(left=parts)
-    elif _all_text_parts(parts):
-        return _utils.Either(right=parts)
-    else:
-        raise exceptions.UnexpectedModelBehavior(
-            f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
-        )
-
-
-def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
-    return all('function_call' in part for part in parts)
-
-
-def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
-    return all('text' in part for part in parts)
-
-
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
@@ -630,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Usage()
+        return usage.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Usage(
+    return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
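Aside from the result.Usage → usage.Usage rename, `_metadata_as_usage` is a useful pattern on its own: missing metadata collapses to an all-zero usage object rather than None, the three token counts default to 0, and the optional cached-content count is tucked into a details dict. A self-contained sketch with a stub dataclass standing in for pydantic_ai.usage.Usage (field names are copied from the hunk; passing `details` through is an assumption, since that line falls outside the hunk's context):

from dataclasses import dataclass, field


@dataclass
class Usage:
    # Reduced stand-in for pydantic_ai.usage.Usage.
    request_tokens: int = 0
    response_tokens: int = 0
    total_tokens: int = 0
    details: dict[str, int] = field(default_factory=dict)


def metadata_as_usage(metadata: dict[str, int] | None) -> Usage:
    # Absent metadata becomes a zero Usage, so callers can always sum usages.
    if metadata is None:
        return Usage()
    details: dict[str, int] = {}
    if cached_content_token_count := metadata.get('cached_content_token_count'):
        details['cached_content_token_count'] = cached_content_token_count
    return Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )


assert metadata_as_usage(None).total_tokens == 0
assert metadata_as_usage({'prompt_token_count': 10, 'total_token_count': 10}).request_tokens == 10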
@@ -693,7 +647,7 @@ class _GeminiJsonSchema:
 
     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-
+        schema.pop('default', None)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -708,11 +662,12 @@ class _GeminiJsonSchema:
         if any_of := schema.get('anyOf'):
             for item_schema in any_of:
                 self._simplify(item_schema, refs_stack)
-            if len(any_of) == 2 and {'type': 'null'} in any_of:
+            if len(any_of) == 2 and {'type': 'null'} in any_of:
                 for item_schema in any_of:
                     if item_schema != {'type': 'null'}:
                         schema.clear()
                         schema.update(item_schema)
+                        schema['nullable'] = True
                 return
 
         type_ = schema.get('type')
@@ -721,6 +676,12 @@ class _GeminiJsonSchema:
             self._object(schema, refs_stack)
         elif type_ == 'array':
             return self._array(schema, refs_stack)
+        elif type_ == 'string' and (fmt := schema.pop('format', None)):
+            description = schema.get('description')
+            if description:
+                schema['description'] = f'{description} (format: {fmt})'
+            else:
+                schema['description'] = f'Format: {fmt}'
 
     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)