pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim has been flagged; see the registry's advisory for details.
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +34 -16
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +333 -148
- pydantic_ai/messages.py +87 -48
- pydantic_ai/models/__init__.py +30 -6
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +59 -31
- pydantic_ai/models/gemini.py +150 -108
- pydantic_ai/models/groq.py +94 -74
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +102 -76
- pydantic_ai/models/test.py +62 -51
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +28 -18
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
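The headline change in this release is the new `pydantic_ai/settings.py` module: every model's `request`/`request_stream` method now takes a `model_settings: ModelSettings | None` argument, and the diffs below thread those settings (`max_tokens`, `temperature`, `top_p`, `timeout`) through to each provider. A minimal usage sketch, assuming the `model_settings` keyword on `Agent.run_sync` that the `agent.py` changes wire up; the model name, prompt, and values are illustrative:

from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')

# ModelSettings is a mapping of optional generation parameters; these are the
# keys the gemini.py diff below actually reads (plus 'timeout').
result = agent.run_sync(
    'What is the capital of France?',
    model_settings={'max_tokens': 256, 'temperature': 0.0, 'top_p': 0.9},
)
print(result.data)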
pydantic_ai/models/function.py
CHANGED
@@ -4,7 +4,7 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast
@@ -13,7 +13,20 @@ import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload
 
 from .. import _utils, result
-from ..messages import …
+from ..messages import (
+    ArgsJson,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
 
@@ -59,7 +72,7 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
+            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
         )
 
     def name(self) -> str:
@@ -88,6 +101,8 @@ class AgentInfo:
     """Whether a plain text result is allowed."""
     result_tools: list[ToolDefinition]
     """The tools that can called as the final result of the run."""
+    model_settings: ModelSettings | None
+    """The model settings passed to the run call."""
 
 
 @dataclass
@@ -106,10 +121,10 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
-FunctionDef: TypeAlias = Callable[[list[…
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
-StreamFunctionDef: TypeAlias = Callable[[list[…
+StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.
 
 While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
@@ -127,18 +142,25 @@ class FunctionAgentModel(AgentModel):
     stream_function: StreamFunctionDef | None
     agent_info: AgentInfo
 
-    async def request(…
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        agent_info = replace(self.agent_info, model_settings=model_settings)
+
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
         if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, …
+            response = await self.function(messages, agent_info)
         else:
-            response_ = await _utils.run_in_executor(self.function, messages, …
-            …
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_cost(chain(messages, [response]))
 
     @asynccontextmanager
-    async def request_stream(…
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
@@ -206,13 +228,13 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[key] = new
 
-    def get(self, *, final: bool = False) -> …
-        calls: list[…
+    def get(self, *, final: bool = False) -> ModelResponse:
+        calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(…
+                calls.append(ToolCallPart.from_json(c.name, c.json_args))
 
-        return …
+        return ModelResponse(calls, timestamp=self._timestamp)
 
     def cost(self) -> result.Cost:
         return result.Cost()
@@ -221,32 +243,38 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def _estimate_cost(messages: Iterable[…
+def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
    """Very rough guesstimate of the number of tokens associate with a series of messages.
 
    This is designed to be used solely to give plausible numbers for testing!
    """
    # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-
    request_tokens = 50
    response_tokens = 0
    for message in messages:
-        if message …
-        … (removed lines truncated in the source view)
-        elif message.role == 'model-structured-response':
-            for call in message.calls:
-                if isinstance(call.args, ArgsJson):
-                    args_str = call.args.args_json
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                    request_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolReturnPart):
+                    request_tokens += _string_cost(part.model_response_str())
+                elif isinstance(part, RetryPromptPart):
+                    request_tokens += _string_cost(part.model_response())
                 else:
-            … (removed lines truncated in the source view)
+                    assert_never(part)
+        elif isinstance(message, ModelResponse):
+            for part in message.parts:
+                if isinstance(part, TextPart):
+                    response_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolCallPart):
+                    call = part
+                    if isinstance(call.args, ArgsJson):
+                        args_str = call.args.args_json
+                    else:
+                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
+                    response_tokens += 1 + _string_cost(args_str)
+                else:
+                    assert_never(part)
         else:
            assert_never(message)
    return result.Cost(…
pydantic_ai/models/gemini.py
CHANGED
@@ -2,29 +2,33 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
 
+import pydantic
 import pydantic_core
-from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from pydantic import Discriminator, Field, Tag
+from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
 
-from .. import UnexpectedModelBehavior, …
+from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
     ArgsDict,
-    … (removed import names truncated in the source view)
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -37,7 +41,9 @@ from . import (
     get_user_agent,
 )
 
-GeminiModelName = Literal[…
+GeminiModelName = Literal[
+    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
+]
 """Named Gemini models.
 
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
@@ -164,26 +170,25 @@ class GeminiAgentModel(AgentModel):
         self.tool_config = tool_config
         self.url = url
 
-    async def request(…
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
         return self._process_response(response), _metadata_as_cost(response)
 
     @asynccontextmanager
-    async def request_stream(…
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
-    async def _make_request(…
-        … (removed lines truncated in the source view)
-            either_content = self._message_to_gemini(m)
-            if left := either_content.left:
-                sys_prompt_parts.append(left.value)
-            else:
-                contents.append(either_content.right)
+    async def _make_request(
+        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+    ) -> AsyncIterator[HTTPResponse]:
+        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -193,6 +198,17 @@ class GeminiAgentModel(AgentModel):
         if self.tool_config is not None:
             request_data['tool_config'] = self.tool_config
 
+        generation_config: _GeminiGenerationConfig = {}
+        if model_settings:
+            if (max_tokens := model_settings.get('max_tokens')) is not None:
+                generation_config['max_output_tokens'] = max_tokens
+            if (temperature := model_settings.get('temperature')) is not None:
+                generation_config['temperature'] = temperature
+            if (top_p := model_settings.get('top_p')) is not None:
+                generation_config['top_p'] = top_p
+        if generation_config:
+            request_data['generation_config'] = generation_config
+
         url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
 
         headers = {
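The `@@ -193,6 +198,17 @@` hunk above is where generic `ModelSettings` keys become Gemini's `generationConfig`. A self-contained sketch of the same mapping, using a public stand-in for the private `_GeminiGenerationConfig` TypedDict:

from typing import Optional, TypedDict


class GenerationConfig(TypedDict, total=False):
    # stand-in for _GeminiGenerationConfig; total=False so unset keys are omitted
    max_output_tokens: int
    temperature: float
    top_p: float


def build_generation_config(model_settings: Optional[dict]) -> GenerationConfig:
    config: GenerationConfig = {}
    if model_settings:
        # copy only the keys the caller actually set, so an empty config can be
        # left out of the request entirely
        if (max_tokens := model_settings.get('max_tokens')) is not None:
            config['max_output_tokens'] = max_tokens
        if (temperature := model_settings.get('temperature')) is not None:
            config['temperature'] = temperature
        if (top_p := model_settings.get('top_p')) is not None:
            config['top_p'] = top_p
    return config


assert build_generation_config({'temperature': 0.0}) == {'temperature': 0.0}
assert build_generation_config(None) == {}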
@@ -203,19 +219,24 @@ class GeminiAgentModel(AgentModel):
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.http_client.stream(…
+        async with self.http_client.stream(
+            'POST',
+            url,
+            content=request_json,
+            headers=headers,
+            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+        ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r
 
     @staticmethod
-    def _process_response(response: _GeminiResponse) -> …
-        … (removed body truncated in the source view)
-        return ModelTextResponse(content=''.join(part['text'] for part in either.right))
+    def _process_response(response: _GeminiResponse) -> ModelResponse:
+        if len(response['candidates']) != 1:
+            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        parts = response['candidates'][0]['content']['parts']
+        return _process_response_from_parts(parts)
 
     @staticmethod
     async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
@@ -239,34 +260,37 @@
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
+        # TODO: Update this once we rework stream responses to be more flexible
         if _extract_response_parts(start_response).is_left():
             return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
         else:
             return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
 
-    @…
-    def …
-    … (removed method body truncated in the source view)
+    @classmethod
+    def _message_to_gemini_content(
+        cls, messages: list[ModelMessage]
+    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
+        sys_prompt_parts: list[_GeminiTextPart] = []
+        contents: list[_GeminiContent] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
+                    elif isinstance(part, UserPromptPart):
+                        contents.append(_content_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        contents.append(_content_tool_return(part))
+                    elif isinstance(part, RetryPromptPart):
+                        contents.append(_content_retry_prompt(part))
+                    else:
+                        assert_never(part)
+            elif isinstance(m, ModelResponse):
+                contents.append(_content_model_response(m))
+            else:
+                assert_never(m)
+
+        return sys_prompt_parts, contents
 
 
 @dataclass
@@ -327,8 +351,8 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             chunk = await self._stream.__anext__()
             self._content.extend(chunk)
 
-    def get(self, *, final: bool = False) -> …
-        """Get the `…
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.
 
         NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
         reply with a single response, when returning a structured data.
@@ -340,20 +364,13 @@
             self._content,
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
-        combined_parts: list[…
+        combined_parts: list[_GeminiPartUnion] = []
         self._cost = result.Cost()
         for r in responses:
             self._cost += _metadata_as_cost(r)
             candidate = r['candidates'][0]
-            … (removed lines truncated in the source view)
-                combined_parts.extend(parts)
-            elif not candidate.get('finish_reason'):
-                # you can get an empty text part along with the finish_reason, so we ignore that case
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be function calls'
-                )
-        return _structured_response_from_parts(combined_parts, timestamp=self._timestamp)
+            combined_parts.extend(candidate['content']['parts'])
+        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
 
     def cost(self) -> result.Cost:
         return self._cost
@@ -367,6 +384,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
 # TypeAdapters take care of validation and serialization
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiRequest(TypedDict):
     """Schema for an API request to the Gemini API.
 
@@ -382,32 +400,37 @@ class _GeminiRequest(TypedDict):
     Developer generated system instructions, see
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
+    generation_config: NotRequired[_GeminiGenerationConfig]
 
 
-class …
-    …
-    parts: list[_GeminiPartUnion]
+class _GeminiGenerationConfig(TypedDict, total=False):
+    """Schema for an API request to the Gemini API.
 
+    Note there are many additional fields available that have not been added yet.
 
-    … (removed lines truncated in the source view)
+    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
+    """
+
+    max_output_tokens: int
+    temperature: float
+    top_p: float
 
 
-… (removed lines truncated in the source view)
+class _GeminiContent(TypedDict):
+    role: Literal['user', 'model']
+    parts: list[_GeminiPartUnion]
 
 
-def …
-    return _GeminiContent(role='model', parts=parts)
+def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
+    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
 
 
-def …
+def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
     f_response = _response_part_from_response(m.tool_name, m.model_response_object())
     return _GeminiContent(role='user', parts=[f_response])
 
 
-def …
+def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
     if m.tool_name is None:
         part = _GeminiTextPart(text=m.model_response())
     else:
@@ -416,26 +439,43 @@ def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
     return _GeminiContent(role='user', parts=[part])
 
 
+def _content_model_response(m: ModelResponse) -> _GeminiContent:
+    parts: list[_GeminiPartUnion] = []
+    for item in m.parts:
+        if isinstance(item, ToolCallPart):
+            parts.append(_function_call_part_from_call(item))
+        elif isinstance(item, TextPart):
+            parts.append(_GeminiTextPart(text=item.content))
+        else:
+            assert_never(item)
+    return _GeminiContent(role='model', parts=parts)
+
+
 class _GeminiTextPart(TypedDict):
     text: str
 
 
 class _GeminiFunctionCallPart(TypedDict):
-    function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
+    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
 
-def _function_call_part_from_call(tool: …
+def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
     assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}'
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
 
 
-def …
-    … (removed function body truncated in the source view)
+def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+    items: list[ModelResponsePart] = []
+    for part in parts:
+        if 'text' in part:
+            items.append(TextPart(part['text']))
+        elif 'function_call' in part:
+            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+        elif 'function_response' in part:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+            )
+    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):
@@ -446,7 +486,7 @@ class _GeminiFunctionCall(TypedDict):
 
 
 class _GeminiFunctionResponsePart(TypedDict):
-    function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
+    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
 
 
 def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
@@ -476,11 +516,11 @@ def _part_discriminator(v: Any) -> str:
 # TODO discriminator
 _GeminiPartUnion = Annotated[
     Union[
-        Annotated[_GeminiTextPart, Tag('text')],
-        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
-        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
+        Annotated[_GeminiTextPart, pydantic.Tag('text')],
+        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
+        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
     ],
-    Discriminator(_part_discriminator),
+    pydantic.Discriminator(_part_discriminator),
 ]
 
 
@@ -490,7 +530,7 @@ class _GeminiTextContent(TypedDict):
 
 
 class _GeminiTools(TypedDict):
-    function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
+    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
 
 
 class _GeminiFunction(TypedDict):
@@ -531,6 +571,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
     allowed_function_names: list[str]
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiResponse(TypedDict):
     """Schema for the response from the Gemini API.
 
@@ -540,10 +581,11 @@ class _GeminiResponse(TypedDict):
 
     candidates: list[_GeminiCandidates]
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
-    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
-    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
+    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
+    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
+# TODO: Delete the next three functions once we've reworked streams to be more flexible
 def _extract_response_parts(
     response: _GeminiResponse,
 ) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
@@ -576,14 +618,14 @@ class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
     content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
     """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
-    avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
+    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
     index: NotRequired[int]
-    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
+    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
 class _GeminiUsageMetaData(TypedDict, total=False):
@@ -592,10 +634,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
 
-    prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
-    candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
-    total_token_count: Annotated[int, Field(alias='totalTokenCount')]
-    cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
+    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
+    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
+    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
+    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
 def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
@@ -629,15 +671,15 @@ class _GeminiSafetyRating(TypedDict):
 class _GeminiPromptFeedback(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
 
-    block_reason: Annotated[str, Field(alias='blockReason')]
-    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
+    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
+    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
 
 
-_gemini_request_ta = …
-_gemini_response_ta = …
+_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
+_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 
 # steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
-_gemini_streamed_response_ta = …
+_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
 
 
 class _GeminiJsonSchema: