pydantic-ai-slim 0.0.6a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +8 -0
- pydantic_ai/_griffe.py +128 -0
- pydantic_ai/_pydantic.py +216 -0
- pydantic_ai/_result.py +258 -0
- pydantic_ai/_retriever.py +114 -0
- pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai/_utils.py +247 -0
- pydantic_ai/agent.py +795 -0
- pydantic_ai/dependencies.py +83 -0
- pydantic_ai/exceptions.py +56 -0
- pydantic_ai/messages.py +205 -0
- pydantic_ai/models/__init__.py +300 -0
- pydantic_ai/models/function.py +268 -0
- pydantic_ai/models/gemini.py +720 -0
- pydantic_ai/models/groq.py +400 -0
- pydantic_ai/models/openai.py +379 -0
- pydantic_ai/models/test.py +389 -0
- pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai/py.typed +0 -0
- pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6a1.dist-info/METADATA +49 -0
- pydantic_ai_slim-0.0.6a1.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6a1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,720 @@
|
|
|
1
|
+
"""Custom interface to the `generativelanguage.googleapis.com` API using [HTTPX] and [Pydantic].
|
|
2
|
+
|
|
3
|
+
The Google SDK for interacting with the `generativelanguage.googleapis.com` API
|
|
4
|
+
[`google-generativeai`](https://ai.google.dev/gemini-api/docs/quickstart?lang=python) reads like it was written by a
|
|
5
|
+
Java developer who thought they knew everything about OOP, spent 30 minutes trying to learn Python,
|
|
6
|
+
gave up and decided to build the library to prove how horrible Python is. It also doesn't use httpx for HTTP requests,
|
|
7
|
+
and tries to implement tool calling itself, but doesn't use Pydantic or equivalent for validation.
|
|
8
|
+
|
|
9
|
+
We could also use the Google Vertex SDK,
|
|
10
|
+
[`google-cloud-aiplatform`](https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk)
|
|
11
|
+
which uses the `*-aiplatform.googleapis.com` API, but that requires a service account for authentication
|
|
12
|
+
which is a faff to set up and manage.
|
|
13
|
+
|
|
14
|
+
Both APIs claim compatibility with OpenAI's API, but that breaks down with even the simplest of requests,
|
|
15
|
+
hence this custom interface.
|
|
16
|
+
|
|
17
|
+
Despite these limitations, the Gemini model is actually quite powerful and very fast.
|
|
18
|
+
|
|
19
|
+
[HTTPX]: https://www.python-httpx.org/
|
|
20
|
+
[Pydantic]: https://docs.pydantic.dev/latest/
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations as _annotations
|
|
24
|
+
|
|
25
|
+
import os
|
|
26
|
+
import re
|
|
27
|
+
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
|
|
28
|
+
from contextlib import asynccontextmanager
|
|
29
|
+
from copy import deepcopy
|
|
30
|
+
from dataclasses import dataclass, field
|
|
31
|
+
from datetime import datetime
|
|
32
|
+
from typing import Annotated, Any, Literal, Protocol, Union
|
|
33
|
+
|
|
34
|
+
import pydantic_core
|
|
35
|
+
from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
|
|
36
|
+
from pydantic import Discriminator, Field, Tag
|
|
37
|
+
from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
|
|
38
|
+
|
|
39
|
+
from .. import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result
|
|
40
|
+
from ..messages import (
|
|
41
|
+
ArgsObject,
|
|
42
|
+
Message,
|
|
43
|
+
ModelAnyResponse,
|
|
44
|
+
ModelStructuredResponse,
|
|
45
|
+
ModelTextResponse,
|
|
46
|
+
RetryPrompt,
|
|
47
|
+
ToolCall,
|
|
48
|
+
ToolReturn,
|
|
49
|
+
)
|
|
50
|
+
from . import (
|
|
51
|
+
AbstractToolDefinition,
|
|
52
|
+
AgentModel,
|
|
53
|
+
EitherStreamedResponse,
|
|
54
|
+
Model,
|
|
55
|
+
StreamStructuredResponse,
|
|
56
|
+
StreamTextResponse,
|
|
57
|
+
cached_async_http_client,
|
|
58
|
+
check_allow_model_requests,
|
|
59
|
+
get_user_agent,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
GeminiModelName = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
]
"""Named Gemini models.

See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass(init=False)
class GeminiModel(Model):
    """A model that uses Gemini via the `generativelanguage.googleapis.com` API.

    This is implemented from scratch rather than with a dedicated SDK; good API documentation is
    available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: GeminiModelName
    auth: AuthProtocol
    http_client: AsyncHTTPClient
    url: str

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        api_key: str | None = None,
        http_client: AsyncHTTPClient | None = None,
        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key to use for authentication. When omitted, the `GEMINI_API_KEY`
                environment variable is used instead.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests; when omitted,
                the client returned by `cached_async_http_client()` is used.
            url_template: The URL template to use for making requests; you shouldn't need to change this.
                See the docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request).
                `{model}` is replaced with the model name, and the endpoint name is appended to the
                end of the URL when a request is made.

        Raises:
            UserError: If `api_key` is omitted and `GEMINI_API_KEY` is unset or empty.
        """
        self.model_name = model_name

        if api_key is None:
            # An empty environment variable counts as "not set".
            api_key = os.getenv('GEMINI_API_KEY')
            if not api_key:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')

        self.auth = ApiKeyAuth(api_key)
        self.http_client = http_client or cached_async_http_client()
        self.url = url_template.format(model=model_name)

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> GeminiAgentModel:
        # Everything connection-related is shared with the agent model; only the tool setup is per-agent.
        return GeminiAgentModel(
            model_name=self.model_name,
            url=self.url,
            auth=self.auth,
            http_client=self.http_client,
            retrievers=retrievers,
            result_tools=result_tools,
            allow_text_result=allow_text_result,
        )

    def name(self) -> str:
        return self.model_name
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class AuthProtocol(Protocol):
    """Structural interface for objects that supply authentication headers for Gemini requests."""

    async def headers(self) -> dict[str, str]:
        """Return the HTTP headers to merge into each request."""
        ...
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass
class ApiKeyAuth:
    """Authenticate Gemini requests with a Google API key sent in a request header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        """Return the single header that carries the API key.

        See https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        """
        header_name = 'X-Goog-Api-Key'
        return {header_name: self.api_key}
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass(init=False)
class GeminiAgentModel(AgentModel):
    """Implementation of `AgentModel` for Gemini models.

    Holds the HTTP client, auth, endpoint URL and the function declarations built from the agent's
    retrievers and result tools. It issues `generateContent` / `streamGenerateContent` requests.
    """

    http_client: AsyncHTTPClient
    model_name: GeminiModelName
    auth: AuthProtocol
    tools: _GeminiTools | None
    tool_config: _GeminiToolConfig | None
    url: str

    def __init__(
        self,
        http_client: AsyncHTTPClient,
        model_name: GeminiModelName,
        auth: AuthProtocol,
        url: str,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ):
        check_allow_model_requests()

        # Retrievers come first, then result tools; this order is kept in the request.
        all_tool_defs = [*retrievers.values(), *(result_tools or ())]
        function_decls = [_function_from_abstract_tool(tool_def) for tool_def in all_tool_defs]

        self.http_client = http_client
        self.model_name = model_name
        self.auth = auth
        self.tools = _GeminiTools(function_declarations=function_decls) if function_decls else None
        # If a plain-text answer isn't acceptable, force the model to call one of the declared functions.
        self.tool_config = None if allow_text_result else _tool_config([decl['name'] for decl in function_decls])
        self.url = url

    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
        """Send `messages` in a single non-streamed request and return the parsed response and its cost."""
        async with self._make_request(messages, False) as http_response:
            raw_body = await http_response.aread()
            response = _gemini_response_ta.validate_json(raw_body)
            return self._process_response(response), _metadata_as_cost(response)

    @asynccontextmanager
    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
        """Send `messages` as a streamed request and yield a text or structured stream wrapper."""
        async with self._make_request(messages, True) as http_response:
            streamed = await self._process_streamed_response(http_response)
            yield streamed

    @asynccontextmanager
    async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]:
        """POST the conversation to Gemini and yield the open HTTP response.

        System prompts are gathered into `system_instruction`; every other message becomes an entry in
        `contents`.

        Raises:
            UnexpectedModelBehavior: If the API answers with a status other than 200.
        """
        system_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for message in messages:
            converted = self._message_to_gemini(message)
            if system_part := converted.left:
                system_parts.append(system_part.value)
            else:
                contents.append(converted.right)

        # Key insertion order here is the order the keys are added to the request dict.
        request_data = _GeminiRequest(contents=contents)
        if system_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=system_parts)
        if self.tools is not None:
            request_data['tools'] = self.tools
        if self.tool_config is not None:
            request_data['tool_config'] = self.tool_config

        endpoint = 'streamGenerateContent' if streamed else 'generateContent'
        url = f'{self.url}{endpoint}'

        headers = {
            'Content-Type': 'application/json',
            'User-Agent': get_user_agent(),
            **await self.auth.headers(),
        }

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

        async with self.http_client.stream('POST', url, content=request_json, headers=headers) as http_response:
            if http_response.status_code != 200:
                # Read the body first so its text can go into the error.
                await http_response.aread()
                raise exceptions.UnexpectedModelBehavior(
                    f'Unexpected response from gemini {http_response.status_code}', http_response.text
                )
            yield http_response

    @staticmethod
    def _process_response(response: _GeminiResponse) -> ModelAnyResponse:
        """Turn a complete Gemini response into a structured (function-call) or text model response."""
        either = _extract_response_parts(response)
        if function_calls := either.left:
            return _structured_response_from_parts(function_calls.value)
        text = ''.join(part['text'] for part in either.right)
        return ModelTextResponse(content=text)

    @staticmethod
    async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return.

        Bytes are buffered until the first response whose first candidate has non-empty parts is seen.
        Those parts decide whether a structured or a text stream wrapper is returned. The wrapper receives
        the bytes read so far plus the rest of the byte iterator.

        Raises:
            UnexpectedModelBehavior: If the stream ends before any such response arrives.
        """
        byte_stream = http_response.aiter_bytes()
        buffer = bytearray()
        first_response: _GeminiResponse | None = None

        async for chunk in byte_stream:
            buffer.extend(chunk)
            parsed = _gemini_streamed_response_ta.validate_json(
                buffer,
                experimental_allow_partial='trailing-strings',
            )
            if not parsed:
                continue
            newest = parsed[-1]
            candidates = newest['candidates']
            if candidates and candidates[0]['content']['parts']:
                first_response = newest
                break

        if first_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        if _extract_response_parts(first_response).is_left():
            return GeminiStreamStructuredResponse(_content=buffer, _stream=byte_stream)
        return GeminiStreamTextResponse(_json_content=buffer, _stream=byte_stream)

    @staticmethod
    def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
        """Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents".

        System prompts come back on the left; every other message kind comes back on the right.
        """
        if m.role == 'system':
            return _utils.Either(left=_GeminiTextPart(text=m.content))
        if m.role == 'user':
            return _utils.Either(right=_content_user_text(m.content))
        if m.role == 'tool-return':
            return _utils.Either(right=_content_function_return(m))
        if m.role == 'retry-prompt':
            return _utils.Either(right=_content_function_retry(m))
        if m.role == 'model-text-response':
            return _utils.Either(right=_content_model_text(m.content))
        if m.role == 'model-structured-response':
            return _utils.Either(right=_content_function_call(m))
        assert_never(m)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
@dataclass
class GeminiStreamTextResponse(StreamTextResponse):
    """Implementation of `StreamTextResponse` for the Gemini model.

    `_json_content` accumulates the raw bytes of the streamed JSON array of responses. `_position`
    counts how many array items `get` has already consumed.
    """

    _json_content: bytearray
    _stream: AsyncIterator[bytes]
    _position: int = 0
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
    _cost: result.Cost = field(default_factory=result.Cost, init=False)

    async def __anext__(self) -> None:
        """Append the next chunk of bytes from the HTTP stream to the buffer."""
        self._json_content.extend(await self._stream.__anext__())

    def get(self, *, final: bool = False) -> Iterable[str]:
        """Yield the text of responses not consumed by a previous call.

        Unless `final` is set, the last array item is held back because it may still be incomplete.
        Each consumed response's usage metadata is added to `_cost`.

        Raises:
            UnexpectedModelBehavior: If a consumed response contains a non-text part.
        """
        items = pydantic_core.from_json(self._json_content, allow_partial=not final)
        if final:
            fresh_items = items[self._position :]
            self._position = len(items)
            responses = _gemini_streamed_response_ta.validate_python(fresh_items)
        else:
            fresh_items = items[self._position : -1]
            self._position = len(items) - 1
            responses = _gemini_streamed_response_ta.validate_python(
                fresh_items, experimental_allow_partial='trailing-strings'
            )

        for response in responses:
            self._cost += _metadata_as_cost(response)
            parts = response['candidates'][0]['content']['parts']
            if not _all_text_parts(parts):
                raise UnexpectedModelBehavior(
                    'Streamed response with unexpected content, expected all parts to be text'
                )
            yield from (part['text'] for part in parts)

    def cost(self) -> result.Cost:
        return self._cost

    def timestamp(self) -> datetime:
        return self._timestamp
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@dataclass
class GeminiStreamStructuredResponse(StreamStructuredResponse):
    """Implementation of `StreamStructuredResponse` for the Gemini model.

    `_content` accumulates the raw bytes of the streamed JSON array of responses.
    """

    _content: bytearray
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
    _cost: result.Cost = field(default_factory=result.Cost, init=False)

    async def __anext__(self) -> None:
        """Append the next chunk of bytes from the HTTP stream to the buffer."""
        self._content.extend(await self._stream.__anext__())

    def get(self, *, final: bool = False) -> ModelStructuredResponse:
        """Get the `ModelStructuredResponse` at this point.

        NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
        reply with a single response when returning structured data.

        Each part is therefore assumed to hold a complete tool call; data from separate parts is not merged.
        The whole buffer is re-parsed on every call, and `_cost` is recomputed from scratch.
        """
        responses = _gemini_streamed_response_ta.validate_json(
            self._content,
            experimental_allow_partial='off' if final else 'trailing-strings',
        )
        self._cost = result.Cost()
        calls: list[_GeminiFunctionCallPart] = []

        for response in responses:
            self._cost += _metadata_as_cost(response)
            candidate = response['candidates'][0]
            parts = candidate['content']['parts']
            if _all_function_call_parts(parts):
                calls.extend(parts)
                continue
            # Non-function parts (e.g. an empty text part) may accompany `finish_reason`; only reject them otherwise.
            if not candidate.get('finish_reason'):
                raise UnexpectedModelBehavior(
                    'Streamed response with unexpected content, expected all parts to be function calls'
                )

        return _structured_response_from_parts(calls, timestamp=self._timestamp)

    def cost(self) -> result.Cost:
        return self._cost

    def timestamp(self) -> datetime:
        return self._timestamp
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
# We use typed dicts to define the Gemini API request and response schemas;
# once Pydantic partial validation supports dataclasses, we could revert to using them.
# TypeAdapters take care of validation and serialization.
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class _GeminiRequest(TypedDict):
|
|
388
|
+
"""Schema for an API request to the Gemini API.
|
|
389
|
+
|
|
390
|
+
See <https://ai.google.dev/api/generate-content#request-body> for API docs.
|
|
391
|
+
"""
|
|
392
|
+
|
|
393
|
+
contents: list[_GeminiContent]
|
|
394
|
+
tools: NotRequired[_GeminiTools]
|
|
395
|
+
tool_config: NotRequired[_GeminiToolConfig]
|
|
396
|
+
# we don't implement `generationConfig`, instead we use a named tool for the response
|
|
397
|
+
system_instruction: NotRequired[_GeminiTextContent]
|
|
398
|
+
"""
|
|
399
|
+
Developer generated system instructions, see
|
|
400
|
+
<https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
|
|
401
|
+
"""
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class _GeminiContent(TypedDict):
|
|
405
|
+
role: Literal['user', 'model']
|
|
406
|
+
parts: list[_GeminiPartUnion]
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _content_user_text(text: str) -> _GeminiContent:
    """Wrap user text as a `user` content holding a single text part."""
    text_part = _GeminiTextPart(text=text)
    return _GeminiContent(role='user', parts=[text_part])
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def _content_model_text(text: str) -> _GeminiContent:
    """Wrap previous model text as a `model` content holding a single text part."""
    text_part = _GeminiTextPart(text=text)
    return _GeminiContent(role='model', parts=[text_part])
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def _content_function_call(m: ModelStructuredResponse) -> _GeminiContent:
    """Convert a structured model response into a `model` content with one function-call part per call."""
    call_parts: list[_GeminiPartUnion] = list(map(_function_call_part_from_call, m.calls))
    return _GeminiContent(role='model', parts=call_parts)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _content_function_return(m: ToolReturn) -> _GeminiContent:
    """Send a tool's return value back to the model as a `user` content with a function-response part."""
    return _GeminiContent(
        role='user',
        parts=[_response_part_from_response(m.tool_name, m.model_response_object())],
    )
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
    """Build the `user` content that asks the model to retry.

    When the retry relates to a specific tool, the message is sent as that tool's function response
    under a `call_error` key. Otherwise it is sent as plain text.
    """
    if m.tool_name is not None:
        part = _response_part_from_response(m.tool_name, {'call_error': m.model_response()})
    else:
        part = _GeminiTextPart(text=m.model_response())
    return _GeminiContent(role='user', parts=[part])
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class _GeminiTextPart(TypedDict):
|
|
437
|
+
text: str
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class _GeminiFunctionCallPart(TypedDict):
|
|
441
|
+
function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _function_call_part_from_call(tool: ToolCall) -> _GeminiFunctionCallPart:
    """Convert a previous `ToolCall` into a Gemini function-call part.

    Only `ArgsObject` arguments are supported; this is enforced with an `assert`.
    """
    args = tool.args
    assert isinstance(args, ArgsObject), f'Expected ArgsObject, got {args}'
    call = _GeminiFunctionCall(name=tool.tool_name, args=args.args_object)
    return _GeminiFunctionCallPart(function_call=call)
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def _structured_response_from_parts(
    parts: list[_GeminiFunctionCallPart], timestamp: datetime | None = None
) -> ModelStructuredResponse:
    """Build a `ModelStructuredResponse` with one `ToolCall` per function-call part.

    If `timestamp` is omitted (or falsy), the current UTC time from `_utils.now_utc()` is used.
    """
    function_calls = (part['function_call'] for part in parts)
    calls = [ToolCall.from_object(call['name'], call['args']) for call in function_calls]
    return ModelStructuredResponse(calls=calls, timestamp=timestamp or _utils.now_utc())
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class _GeminiFunctionCall(TypedDict):
|
|
459
|
+
"""See <https://ai.google.dev/api/caching#FunctionCall>."""
|
|
460
|
+
|
|
461
|
+
name: str
|
|
462
|
+
args: dict[str, Any]
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class _GeminiFunctionResponsePart(TypedDict):
|
|
466
|
+
function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap `response` as the function-response part for the function called `name`."""
    function_response = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=function_response)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
class _GeminiFunctionResponse(TypedDict):
|
|
474
|
+
"""See <https://ai.google.dev/api/caching#FunctionResponse>."""
|
|
475
|
+
|
|
476
|
+
name: str
|
|
477
|
+
response: dict[str, Any]
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _part_discriminator(v: Any) -> str:
|
|
481
|
+
if isinstance(v, dict):
|
|
482
|
+
if 'text' in v:
|
|
483
|
+
return 'text'
|
|
484
|
+
elif 'functionCall' in v or 'function_call' in v:
|
|
485
|
+
return 'function_call'
|
|
486
|
+
elif 'functionResponse' in v or 'function_response' in v:
|
|
487
|
+
return 'function_response'
|
|
488
|
+
return 'text'
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
# See <https://ai.google.dev/api/caching#Part>
# Only text, function-call and function-response parts are currently supported; `_part_discriminator`
# inspects the keys of each part to pick which member to validate it as.
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, Tag('text')],
        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
    ],
    Discriminator(_part_discriminator),
]
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
class _GeminiTextContent(TypedDict):
|
|
505
|
+
role: Literal['user', 'model']
|
|
506
|
+
parts: list[_GeminiTextPart]
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
class _GeminiTools(TypedDict):
|
|
510
|
+
function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
class _GeminiFunction(TypedDict):
|
|
514
|
+
name: str
|
|
515
|
+
description: str
|
|
516
|
+
parameters: NotRequired[dict[str, Any]]
|
|
517
|
+
"""
|
|
518
|
+
ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
|
|
519
|
+
<https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
|
|
520
|
+
and
|
|
521
|
+
<https://ai.google.dev/api/caching#FunctionDeclaration>
|
|
522
|
+
"""
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction:
    """Build a Gemini function declaration from a tool definition.

    The tool's JSON schema is simplified with `_GeminiJsonSchema`. It is attached as `parameters` only
    when it has a non-empty `properties` entry.
    """
    simplified = _GeminiJsonSchema(tool.json_schema).simplify()
    declaration = _GeminiFunction(name=tool.name, description=tool.description)
    if simplified.get('properties'):
        declaration['parameters'] = simplified
    return declaration
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
class _GeminiToolConfig(TypedDict):
|
|
537
|
+
function_calling_config: _GeminiFunctionCallingConfig
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config in `ANY` mode, restricted to `function_names`."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
class _GeminiFunctionCallingConfig(TypedDict):
|
|
547
|
+
mode: Literal['ANY', 'AUTO']
|
|
548
|
+
allowed_function_names: list[str]
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
class _GeminiResponse(TypedDict):
|
|
552
|
+
"""Schema for the response from the Gemini API.
|
|
553
|
+
|
|
554
|
+
See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
|
|
555
|
+
and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
|
|
556
|
+
"""
|
|
557
|
+
|
|
558
|
+
candidates: list[_GeminiCandidates]
|
|
559
|
+
# usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
|
|
560
|
+
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
|
|
561
|
+
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def _extract_response_parts(
    response: _GeminiResponse,
) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
    """Extract the parts of the response from the Gemini API.

    Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
    An empty parts list counts as function calls, because the function-call check runs first and is
    vacuously true.

    Raises:
        UnexpectedModelBehavior: If there isn't exactly one candidate, or the parts mix kinds.
    """
    candidates = response['candidates']
    if len(candidates) != 1:
        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')

    parts = candidates[0]['content']['parts']
    if _all_function_call_parts(parts):
        return _utils.Either(left=parts)
    if _all_text_parts(parts):
        return _utils.Either(right=parts)
    raise exceptions.UnexpectedModelBehavior(
        f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
    )
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
|
|
585
|
+
return all('function_call' in part for part in parts)
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
|
|
589
|
+
return all('text' in part for part in parts)
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
class _GeminiCandidates(TypedDict):
|
|
593
|
+
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
|
|
594
|
+
|
|
595
|
+
content: _GeminiContent
|
|
596
|
+
finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
|
|
597
|
+
"""
|
|
598
|
+
See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
|
|
599
|
+
but let's wait until we see them and know what they mean to add them here.
|
|
600
|
+
"""
|
|
601
|
+
avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
|
|
602
|
+
index: NotRequired[int]
|
|
603
|
+
safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
class _GeminiUsageMetaData(TypedDict, total=False):
|
|
607
|
+
"""See <https://ai.google.dev/api/generate-content#FinishReason>.
|
|
608
|
+
|
|
609
|
+
The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
|
|
610
|
+
"""
|
|
611
|
+
|
|
612
|
+
prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
|
|
613
|
+
candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
|
|
614
|
+
total_token_count: Annotated[int, Field(alias='totalTokenCount')]
|
|
615
|
+
cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
    """Convert a response's `usage_metadata` into a `result.Cost`.

    If there is no metadata, an empty `result.Cost()` is returned. Missing token counts default to 0.
    A non-zero cached-content count goes into `details`.
    """
    metadata = response.get('usage_metadata')
    if metadata is None:
        return result.Cost()

    cached_tokens = metadata.get('cached_content_token_count')
    details: dict[str, int] = {'cached_content_token_count': cached_tokens} if cached_tokens else {}

    return result.Cost(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
class _GeminiSafetyRating(TypedDict):
|
|
634
|
+
"""See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>."""
|
|
635
|
+
|
|
636
|
+
category: Literal[
|
|
637
|
+
'HARM_CATEGORY_HARASSMENT',
|
|
638
|
+
'HARM_CATEGORY_HATE_SPEECH',
|
|
639
|
+
'HARM_CATEGORY_SEXUALLY_EXPLICIT',
|
|
640
|
+
'HARM_CATEGORY_DANGEROUS_CONTENT',
|
|
641
|
+
'HARM_CATEGORY_CIVIC_INTEGRITY',
|
|
642
|
+
]
|
|
643
|
+
probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
class _GeminiPromptFeedback(TypedDict):
|
|
647
|
+
"""See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
|
|
648
|
+
|
|
649
|
+
block_reason: Annotated[str, Field(alias='blockReason')]
|
|
650
|
+
safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
# Type adapters used to serialize requests and validate responses.
_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = _pydantic.LazyTypeAdapter(list[_GeminiResponse])
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
class _GeminiJsonSchema:
|
|
661
|
+
"""Transforms the JSON Schema from Pydantic to be suitable for Gemini.
|
|
662
|
+
|
|
663
|
+
Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
|
|
664
|
+
a subset of OpenAPI v3.0.3.
|
|
665
|
+
|
|
666
|
+
Specifically:
|
|
667
|
+
* gemini doesn't allow the `title` keyword to be set
|
|
668
|
+
* gemini doesn't allow `$defs` — we need to inline the definitions where possible
|
|
669
|
+
"""
|
|
670
|
+
|
|
671
|
+
def __init__(self, schema: _utils.ObjectJsonSchema):
|
|
672
|
+
self.schema = deepcopy(schema)
|
|
673
|
+
self.defs = self.schema.pop('$defs', {})
|
|
674
|
+
|
|
675
|
+
def simplify(self) -> dict[str, Any]:
|
|
676
|
+
self._simplify(self.schema, refs_stack=())
|
|
677
|
+
return self.schema
|
|
678
|
+
|
|
679
|
+
def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
|
|
680
|
+
schema.pop('title', None)
|
|
681
|
+
schema.pop('default', None)
|
|
682
|
+
if ref := schema.pop('$ref', None):
|
|
683
|
+
# noinspection PyTypeChecker
|
|
684
|
+
key = re.sub(r'^#/\$defs/', '', ref)
|
|
685
|
+
if key in refs_stack:
|
|
686
|
+
raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
|
|
687
|
+
refs_stack += (key,)
|
|
688
|
+
schema_def = self.defs[key]
|
|
689
|
+
self._simplify(schema_def, refs_stack)
|
|
690
|
+
schema.update(schema_def)
|
|
691
|
+
return
|
|
692
|
+
|
|
693
|
+
if any_of := schema.get('anyOf'):
|
|
694
|
+
for schema in any_of:
|
|
695
|
+
self._simplify(schema, refs_stack)
|
|
696
|
+
|
|
697
|
+
type_ = schema.get('type')
|
|
698
|
+
|
|
699
|
+
if type_ == 'object':
|
|
700
|
+
self._object(schema, refs_stack)
|
|
701
|
+
elif type_ == 'array':
|
|
702
|
+
return self._array(schema, refs_stack)
|
|
703
|
+
|
|
704
|
+
def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
|
|
705
|
+
ad_props = schema.pop('additionalProperties', None)
|
|
706
|
+
if ad_props:
|
|
707
|
+
raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')
|
|
708
|
+
|
|
709
|
+
if properties := schema.get('properties'): # pragma: no branch
|
|
710
|
+
for value in properties.values():
|
|
711
|
+
self._simplify(value, refs_stack)
|
|
712
|
+
|
|
713
|
+
def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
|
|
714
|
+
if prefix_items := schema.get('prefixItems'):
|
|
715
|
+
# TODO I think this not is supported by Gemini, maybe we should raise an error?
|
|
716
|
+
for prefix_item in prefix_items:
|
|
717
|
+
self._simplify(prefix_item, refs_stack)
|
|
718
|
+
|
|
719
|
+
if items_schema := schema.get('items'): # pragma: no branch
|
|
720
|
+
self._simplify(items_schema, refs_stack)
|