pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_agent_graph.py +12 -8
- pydantic_ai/agent.py +2 -2
- pydantic_ai/models/__init__.py +39 -37
- pydantic_ai/models/anthropic.py +69 -66
- pydantic_ai/models/cohere.py +56 -68
- pydantic_ai/models/function.py +58 -60
- pydantic_ai/models/gemini.py +132 -99
- pydantic_ai/models/groq.py +79 -72
- pydantic_ai/models/mistral.py +72 -71
- pydantic_ai/models/openai.py +90 -70
- pydantic_ai/models/test.py +81 -93
- pydantic_ai/models/vertexai.py +38 -44
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.23.dist-info}/METADATA +3 -3
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.23.dist-info}/RECORD +15 -15
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.23.dist-info}/WHEEL +0 -0
pydantic_ai/models/function.py
CHANGED
@@ -4,7 +4,7 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union
@@ -27,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, Model, StreamedResponse
+from . import Model, ModelRequestParameters, StreamedResponse


 @dataclass(init=False)
@@ -40,6 +40,9 @@ class FunctionModel(Model):
     function: FunctionDef | None = None
     stream_function: StreamFunctionDef | None = None

+    _model_name: str = field(repr=False)
+    _system: str | None = field(default=None, repr=False)
+
     @overload
     def __init__(self, function: FunctionDef) -> None: ...

@@ -63,23 +66,60 @@
         self.function = function
         self.stream_function = stream_function

-    async def agent_model(
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        self._model_name = f'function:{function_name}:{stream_function_name}'
+
+    async def request(
         self,
-        …
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
         )

-        …
+        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+
+        if inspect.iscoroutinefunction(self.function):
+            response = await self.function(messages, agent_info)
+        else:
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
+        response.model_name = f'function:{self.function.__name__}'
+        # TODO is `messages` right here? Should it just be new messages?
+        return response, _estimate_usage(chain(messages, [response]))
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
+        )
+
+        assert (
+            self.stream_function is not None
+        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+
+        response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
+
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
+
+        yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)


 @dataclass(frozen=True)
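The upshot of this hunk is that `FunctionModel` now implements `request`/`request_stream` itself rather than delegating to a separate `AgentModel` object, with tools and settings arriving per call via `ModelRequestParameters`. A minimal sketch of how the rewritten class is driven from an `Agent` (the `echo` function and prompt are illustrative, not from the diff):

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
    from pydantic_ai.models.function import AgentInfo, FunctionModel

    def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # A FunctionDef: receives the full message history plus AgentInfo and
        # returns a complete ModelResponse (a coroutine function also works).
        return ModelResponse(parts=[TextPart('hello world')])

    agent = Agent(FunctionModel(echo))
    result = agent.run_sync('testing')
    # result.data == 'hello world'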
@@ -119,9 +159,11 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""

+# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
 FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""

+# TODO: Change signature as indicated above
 StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.

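These aliases are the contract for user-supplied model functions; a stream function must yield only text chunks or only `DeltaToolCalls` dicts, never a mix. A hedged sketch of each shape (the function names and tool arguments are illustrative):

    from collections.abc import AsyncIterator

    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls

    async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
        # Text-only stream: yield str chunks.
        yield 'hello '
        yield 'world'

    async def stream_tool_call(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
        # Tool-call stream: incremental JSON args keyed by tool-call index.
        yield {0: DeltaToolCall(name='get_weather')}
        yield {0: DeltaToolCall(json_args='{"city": "London"}')}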
@@ -132,50 +174,6 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
 """


-@dataclass
-class FunctionAgentModel(AgentModel):
-    """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    function: FunctionDef | None
-    stream_function: StreamFunctionDef | None
-    agent_info: AgentInfo
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        agent_info = replace(self.agent_info, model_settings=model_settings)
-
-        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
-        model_name = f'function:{self.function.__name__}'
-
-        if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, agent_info)
-        else:
-            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
-            assert isinstance(response_, ModelResponse), response_
-            response = response_
-        response.model_name = model_name
-        # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_usage(chain(messages, [response]))
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        model_name = f'function:{self.stream_function.__name__}'
-
-        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
-
-        first = await response_stream.peek()
-        if isinstance(first, _utils.Unset):
-            raise ValueError('Stream function must return at least one item')
-
-        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
-
-
 @dataclass
 class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
pydantic_ai/models/gemini.py
CHANGED
@@ -31,15 +31,15 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
 )

-GeminiModelName = Literal[
+LatestGeminiModelNames = Literal[
     'gemini-1.5-flash',
     'gemini-1.5-flash-8b',
     'gemini-1.5-pro',
@@ -48,8 +48,13 @@ GeminiModelName = Literal[
     'gemini-2.0-flash-thinking-exp-01-21',
     'gemini-exp-1206',
 ]
-"""
+"""Latest Gemini models."""

+GeminiModelName = Union[str, LatestGeminiModelNames]
+"""Possible Gemini model names.
+
+Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """

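Widening `GeminiModelName` from a closed `Literal` to `Union[str, LatestGeminiModelNames]` keeps autocomplete for the listed names while accepting arbitrary (e.g. date-stamped) model strings. A sketch, assuming `GEMINI_API_KEY` is set in the environment:

    from pydantic_ai.models.gemini import GeminiModel

    listed = GeminiModel('gemini-1.5-flash')    # a LatestGeminiModelNames member
    dated = GeminiModel('gemini-1.5-pro-002')   # any other string now type-checks too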
@@ -57,7 +62,7 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 class GeminiModelSettings(ModelSettings):
     """Settings used for a Gemini model request."""

-    …
+    gemini_safety_settings: list[GeminiSafetySettings]


 @dataclass(init=False)
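With `gemini_safety_settings` on `GeminiModelSettings`, safety thresholds can be passed per run. A hedged sketch of the intended call pattern (the prompt and threshold choices are illustrative, and `GEMINI_API_KEY` is assumed to be set):

    from pydantic_ai import Agent
    from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings

    agent = Agent(GeminiModel('gemini-1.5-flash'))
    result = agent.run_sync(
        'Tell me a joke about a cat.',
        model_settings=GeminiModelSettings(
            gemini_safety_settings=[
                {'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
            ],
        ),
    )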
@@ -70,10 +75,12 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    …
+    http_client: AsyncHTTPClient = field(repr=False)
+
+    _model_name: GeminiModelName = field(repr=False)
+    _auth: AuthProtocol | None = field(repr=False)
+    _url: str | None = field(repr=False)
+    _system: str | None = field(default='google-gla', repr=False)

     def __init__(
         self,
@@ -94,121 +101,87 @@ class GeminiModel(Model):
         docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
         `model` is substituted with the model name, and `function` is added to the end of the URL.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if api_key is None:
             if env_api_key := os.getenv('GEMINI_API_KEY'):
                 api_key = env_api_key
             else:
                 raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-        self.auth = ApiKeyAuth(api_key)
         self.http_client = http_client or cached_async_http_client()
-        self.url = url_template.format(model=model_name)
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> GeminiAgentModel:
-        check_allow_model_requests()
-        return GeminiAgentModel(
-            http_client=self.http_client,
-            model_name=self.model_name,
-            auth=self.auth,
-            url=self.url,
-            function_tools=function_tools,
-            allow_text_result=allow_text_result,
-            result_tools=result_tools,
-        )
-
-    def name(self) -> str:
-        return f'google-gla:{self.model_name}'
-
-
-class AuthProtocol(Protocol):
-    """Abstract definition for Gemini authentication."""
-
-    async def headers(self) -> dict[str, str]: ...
-
-
-@dataclass
-class ApiKeyAuth:
-    """Authentication using an API key for the `X-Goog-Api-Key` header."""
-
-    api_key: str
+        self._auth = ApiKeyAuth(api_key)
+        self._url = url_template.format(model=model_name)

-        …
-
-@dataclass(init=False)
-class GeminiAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Gemini models."""
-
-    http_client: AsyncHTTPClient
-    model_name: GeminiModelName
-    auth: AuthProtocol
-    tools: _GeminiTools | None
-    tool_config: _GeminiToolConfig | None
-    url: str
-
-    def __init__(
-        self,
-        http_client: AsyncHTTPClient,
-        model_name: GeminiModelName,
-        auth: AuthProtocol,
-        url: str,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ):
-        tools = [_function_from_abstract_tool(t) for t in function_tools]
-        if result_tools:
-            tools += [_function_from_abstract_tool(t) for t in result_tools]
+    @property
+    def auth(self) -> AuthProtocol:
+        assert self._auth is not None, 'Auth not initialized'
+        return self._auth

-        …
-        self.http_client = http_client
-        self.model_name = model_name
-        self.auth = auth
-        self.tools = _GeminiTools(function_declarations=tools) if tools else None
-        self.tool_config = tool_config
-        self.url = url
+    @property
+    def url(self) -> str:
+        assert self._url is not None, 'URL not initialized'
+        return self._url

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
+        check_allow_model_requests()
         async with self._make_request(
-            messages, False, cast(GeminiModelSettings, model_settings or {})
+            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
+        check_allow_model_requests()
+        async with self._make_request(
+            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
+        ) as http_response:
             yield await self._process_streamed_response(http_response)

+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
+        return _GeminiTools(function_declarations=tools) if tools else None
+
+    def _get_tool_config(
+        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
+    ) -> _GeminiToolConfig | None:
+        if model_request_parameters.allow_text_result:
+            return None
+        elif tools:
+            return _tool_config([t['name'] for t in tools['function_declarations']])
+        else:
+            return _tool_config([])
+
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
+        self,
+        messages: list[ModelMessage],
+        streamed: bool,
+        model_settings: GeminiModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[HTTPResponse]:
+        tools = self._get_tools(model_request_parameters)
+        tool_config = self._get_tool_config(model_request_parameters, tools)
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)

         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
             request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
-        if self.tools is not None:
-            request_data['tools'] = self.tools
-        if self.tool_config is not None:
-            request_data['tool_config'] = self.tool_config
+        if tools is not None:
+            request_data['tools'] = tools
+        if tool_config is not None:
+            request_data['tool_config'] = tool_config

         generation_config: _GeminiGenerationConfig = {}
         if model_settings:
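Taken together, this hunk replaces the two-step `agent_model(...)` → `AgentModel.request(...)` flow with a single `Model.request(...)` that receives `ModelRequestParameters` per call. A sketch of the new call shape, assuming the 0.0.23 signatures shown in this diff (the prompt and placeholder API key are illustrative):

    import asyncio

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.gemini import GeminiModel

    async def main():
        model = GeminiModel('gemini-1.5-flash', api_key='...')
        # Tools and result handling now travel with each request:
        params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
        messages = [ModelRequest(parts=[UserPromptPart('What is the capital of France?')])]
        response, request_usage = await model.request(messages, None, params)
        print(response.parts)

    asyncio.run(main())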
@@ -222,6 +195,8 @@ class GeminiAgentModel(AgentModel):
             generation_config['presence_penalty'] = presence_penalty
         if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
             generation_config['frequency_penalty'] = frequency_penalty
+        if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
+            request_data['safety_settings'] = gemini_safety_settings
         if generation_config:
             request_data['generation_config'] = generation_config

@@ -250,8 +225,13 @@ class GeminiAgentModel(AgentModel):
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        if 'content' not in response['candidates'][0]:
+            if response['candidates'][0].get('finish_reason') == 'SAFETY':
+                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
+            else:
+                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, model_name=self.model_name)
+        return _process_response_from_parts(parts, model_name=self._model_name)

     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
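Callers now get an `UnexpectedModelBehavior` instead of a schema-validation failure when Gemini withholds content; when the new safety settings trip, the message is 'Safety settings triggered'. A hedged sketch of handling it (assuming an `agent` built on `GeminiModel` as above):

    from pydantic_ai.exceptions import UnexpectedModelBehavior

    try:
        result = agent.run_sync('Write something the safety filter may block.')
    except UnexpectedModelBehavior as e:
        # e.g. 'Safety settings triggered', with the raw response attached as the body
        print(e)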
@@ -267,14 +247,14 @@ class GeminiAgentModel(AgentModel):
         )
         if responses:
             last = responses[-1]
-            if last['candidates'] and last['candidates'][0]['content']['parts']:
+            if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                 start_response = last
                 break

         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)

     @classmethod
     def _message_to_gemini_content(
@@ -312,6 +292,23 @@ class GeminiAgentModel(AgentModel):
         return sys_prompt_parts, contents


+class AuthProtocol(Protocol):
+    """Abstract definition for Gemini authentication."""
+
+    async def headers(self) -> dict[str, str]: ...
+
+
+@dataclass
+class ApiKeyAuth:
+    """Authentication using an API key for the `X-Goog-Api-Key` header."""
+
+    api_key: str
+
+    async def headers(self) -> dict[str, str]:
+        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
+        return {'X-Goog-Api-Key': self.api_key}
+
+
 @dataclass
 class GeminiStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for the Gemini model."""
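`AuthProtocol` and `ApiKeyAuth` move below the model class unchanged; the protocol is structural, just 'an object with an async `headers()` method', so alternative credential schemes can slot in. A hypothetical sketch (the `BearerAuth` class is not part of the package):

    from dataclasses import dataclass

    @dataclass
    class BearerAuth:
        token: str

        async def headers(self) -> dict[str, str]:
            # Satisfies AuthProtocol structurally; no inheritance needed.
            return {'Authorization': f'Bearer {self.token}'}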
@@ -323,6 +320,8 @@ class GeminiStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for gemini_response in self._get_gemini_responses():
             candidate = gemini_response['candidates'][0]
+            if 'content' not in candidate:
+                raise UnexpectedModelBehavior('Streamed response has no content field')
             gemini_part: _GeminiPartUnion
             for gemini_part in candidate['content']['parts']:
                 if 'text' in gemini_part:
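The streamed path gets the same guard: a chunk without `content` now fails fast with `UnexpectedModelBehavior` rather than a `KeyError` deeper in the iterator. For reference, a hedged sketch of the streaming surface this protects (assuming an `agent` as above):

    async def stream_demo():
        async with agent.run_stream('Tell me a short story.') as result:
            async for text in result.stream_text():
                print(text)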
@@ -396,6 +395,7 @@ class _GeminiRequest(TypedDict):
     contents: list[_GeminiContent]
     tools: NotRequired[_GeminiTools]
     tool_config: NotRequired[_GeminiToolConfig]
+    safety_settings: NotRequired[list[GeminiSafetySettings]]
     # we don't implement `generationConfig`, instead we use a named tool for the response
     system_instruction: NotRequired[_GeminiTextContent]
     """
@@ -405,6 +405,38 @@ class _GeminiRequest(TypedDict):
     generation_config: NotRequired[_GeminiGenerationConfig]


+class GeminiSafetySettings(TypedDict):
+    """Safety settings options for Gemini model request.
+
+    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
+    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
+    """
+
+    category: Literal[
+        'HARM_CATEGORY_UNSPECIFIED',
+        'HARM_CATEGORY_HARASSMENT',
+        'HARM_CATEGORY_HATE_SPEECH',
+        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+        'HARM_CATEGORY_DANGEROUS_CONTENT',
+        'HARM_CATEGORY_CIVIC_INTEGRITY',
+    ]
+    """
+    Safety settings category.
+    """
+
+    threshold: Literal[
+        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+        'BLOCK_LOW_AND_ABOVE',
+        'BLOCK_MEDIUM_AND_ABOVE',
+        'BLOCK_ONLY_HIGH',
+        'BLOCK_NONE',
+        'OFF',
+    ]
+    """
+    Safety settings threshold.
+    """
+
+
 class _GeminiGenerationConfig(TypedDict, total=False):
     """Schema for an API request to the Gemini API.

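Since `GeminiSafetySettings` is a `TypedDict`, plain dict literals are checked against the two `Literal` fields. A small sketch (the particular category/threshold pairs are illustrative):

    from pydantic_ai.models.gemini import GeminiSafetySettings

    safety: list[GeminiSafetySettings] = [
        {'category': 'HARM_CATEGORY_HATE_SPEECH', 'threshold': 'BLOCK_MEDIUM_AND_ABOVE'},
        {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_ONLY_HIGH'},
    ]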
@@ -581,8 +613,8 @@ class _GeminiResponse(TypedDict):
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

-    content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
+    content: NotRequired[_GeminiContent]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
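Making `content` optional and adding `'SAFETY'` means a blocked candidate like the following (abridged; camelCase keys per the REST API aliases, and the `blocked` key is added in the next hunk) now validates and reaches the new error handling instead of failing schema validation:

    blocked_candidate = {
        'finishReason': 'SAFETY',
        'safetyRatings': [
            {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'HIGH', 'blocked': True},
        ],
    }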
@@ -630,6 +662,7 @@ class _GeminiSafetyRating(TypedDict):
         'HARM_CATEGORY_CIVIC_INTEGRITY',
     ]
     probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
+    blocked: NotRequired[bool]


 class _GeminiPromptFeedback(TypedDict):