pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
- pydantic_ai/_agent_graph.py +774 -0
- pydantic_ai/agent.py +183 -555
- pydantic_ai/models/__init__.py +43 -37
- pydantic_ai/models/anthropic.py +69 -66
- pydantic_ai/models/cohere.py +56 -68
- pydantic_ai/models/function.py +58 -60
- pydantic_ai/models/gemini.py +139 -100
- pydantic_ai/models/groq.py +79 -72
- pydantic_ai/models/mistral.py +72 -71
- pydantic_ai/models/openai.py +96 -71
- pydantic_ai/models/test.py +81 -93
- pydantic_ai/models/vertexai.py +38 -44
- pydantic_ai/settings.py +5 -0
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/METADATA +3 -4
- pydantic_ai_slim-0.0.23.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.21.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/WHEEL +0 -0
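The change running through all of the model files below is the removal of the per-run `AgentModel` wrapper: `Model.request` and `Model.request_stream` are now called directly and receive the run's tool definitions via a `ModelRequestParameters` object. A minimal sketch of the new call shape, inferred from the signatures visible in these diffs rather than taken from the package docs (`TestModel` is the built-in stub model from `pydantic_ai/models/test.py`; the `ModelRequestParameters` field names follow the `AgentInfo` construction shown in function.py):

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.test import TestModel


async def main():
    # 0.0.21 first built an AgentModel via `await model.agent_model(...)`;
    # 0.0.23 passes the run parameters straight into request().
    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    messages = [ModelRequest(parts=[UserPromptPart(content='hello')])]
    response, request_usage = await TestModel().request(messages, None, params)
    print(response.parts, request_usage)


asyncio.run(main())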
pydantic_ai/models/function.py
CHANGED

@@ -4,7 +4,7 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union
@@ -27,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, Model, StreamedResponse
+from . import Model, ModelRequestParameters, StreamedResponse


 @dataclass(init=False)
@@ -40,6 +40,9 @@ class FunctionModel(Model):
     function: FunctionDef | None = None
     stream_function: StreamFunctionDef | None = None

+    _model_name: str = field(repr=False)
+    _system: str | None = field(default=None, repr=False)
+
     @overload
     def __init__(self, function: FunctionDef) -> None: ...

@@ -63,23 +66,60 @@ class FunctionModel(Model):
         self.function = function
         self.stream_function = stream_function

-    async def agent_model(
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        self._model_name = f'function:{function_name}:{stream_function_name}'
+
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> FunctionAgentModel:
-        return FunctionAgentModel(
-            self.function,
-            self.stream_function,
-            AgentInfo(function_tools, allow_text_result, result_tools, None),
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
         )

-    def name(self) -> str:
-        function_name = self.function.__name__ if self.function is not None else ''
-        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        return f'function:{function_name}:{stream_function_name}'
+        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+
+        if inspect.iscoroutinefunction(self.function):
+            response = await self.function(messages, agent_info)
+        else:
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
+        response.model_name = f'function:{self.function.__name__}'
+        # TODO is `messages` right here? Should it just be new messages?
+        return response, _estimate_usage(chain(messages, [response]))
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
+        )
+
+        assert (
+            self.stream_function is not None
+        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+
+        response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
+
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
+
+        yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)


 @dataclass(frozen=True)
@@ -119,9 +159,11 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""

+# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
 FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""

+# TODO: Change signature as indicated above
 StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.

@@ -132,50 +174,6 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
 """


-@dataclass
-class FunctionAgentModel(AgentModel):
-    """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    function: FunctionDef | None
-    stream_function: StreamFunctionDef | None
-    agent_info: AgentInfo
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        agent_info = replace(self.agent_info, model_settings=model_settings)
-
-        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
-        model_name = f'function:{self.function.__name__}'
-
-        if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, agent_info)
-        else:
-            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
-            assert isinstance(response_, ModelResponse), response_
-            response = response_
-        response.model_name = model_name
-        # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_usage(chain(messages, [response]))
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        model_name = f'function:{self.stream_function.__name__}'
-
-        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
-
-        first = await response_stream.peek()
-        if isinstance(first, _utils.Unset):
-            raise ValueError('Stream function must return at least one item')
-
-        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
-
-
 @dataclass
 class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
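The `FunctionDef` contract is unchanged by this refactor: user functions still receive `(messages, AgentInfo)`, and the TODOs above only flag a future signature change. A minimal sketch of exercising `FunctionModel` through an agent, assuming the public `Agent` API of this release:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def always_hello(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # `info` carries the run's function_tools/result_tools plus model_settings,
    # now built inside FunctionModel.request from ModelRequestParameters.
    return ModelResponse(parts=[TextPart('hello')])


agent = Agent(FunctionModel(always_hello))
result = agent.run_sync('hi')
print(result.data)  # hello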
pydantic_ai/models/gemini.py
CHANGED

@@ -31,19 +31,30 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
 )

-GeminiModelName = Literal[
-    'gemini-1.5-flash',
+LatestGeminiModelNames = Literal[
+    'gemini-1.5-flash',
+    'gemini-1.5-flash-8b',
+    'gemini-1.5-pro',
+    'gemini-1.0-pro',
+    'gemini-2.0-flash-exp',
+    'gemini-2.0-flash-thinking-exp-01-21',
+    'gemini-exp-1206',
 ]
-"""
+"""Latest Gemini models."""

+GeminiModelName = Union[str, LatestGeminiModelNames]
+"""Possible Gemini model names.
+
+Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """

@@ -51,7 +62,7 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 class GeminiModelSettings(ModelSettings):
     """Settings used for a Gemini model request."""

-    # This class is a placeholder for any future gemini-specific settings
+    gemini_safety_settings: list[GeminiSafetySettings]


 @dataclass(init=False)
@@ -64,10 +75,12 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name: GeminiModelName
-    auth: AuthProtocol
-    url: str
-    http_client: AsyncHTTPClient
+    http_client: AsyncHTTPClient = field(repr=False)
+
+    _model_name: GeminiModelName = field(repr=False)
+    _auth: AuthProtocol | None = field(repr=False)
+    _url: str | None = field(repr=False)
+    _system: str | None = field(default='google-gla', repr=False)

     def __init__(
         self,
@@ -88,121 +101,87 @@ class GeminiModel(Model):
         docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
         `model` is substituted with the model name, and `function` is added to the end of the URL.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if api_key is None:
             if env_api_key := os.getenv('GEMINI_API_KEY'):
                 api_key = env_api_key
             else:
                 raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-        self.auth = ApiKeyAuth(api_key)
         self.http_client = http_client or cached_async_http_client()
-        self.url = url_template.format(model=model_name)
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> GeminiAgentModel:
-        check_allow_model_requests()
-        return GeminiAgentModel(
-            http_client=self.http_client,
-            model_name=self.model_name,
-            auth=self.auth,
-            url=self.url,
-            function_tools=function_tools,
-            allow_text_result=allow_text_result,
-            result_tools=result_tools,
-        )
-
-    def name(self) -> str:
-        return f'google-gla:{self.model_name}'
-
-
-class AuthProtocol(Protocol):
-    """Abstract definition for Gemini authentication."""
-
-    async def headers(self) -> dict[str, str]: ...
-
-
-@dataclass
-class ApiKeyAuth:
-    """Authentication using an API key for the `X-Goog-Api-Key` header."""
-
-    api_key: str
+        self._auth = ApiKeyAuth(api_key)
+        self._url = url_template.format(model=model_name)

-    async def headers(self) -> dict[str, str]:
-        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
-        return {'X-Goog-Api-Key': self.api_key}
-
-
-@dataclass(init=False)
-class GeminiAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Gemini models."""
-
-    http_client: AsyncHTTPClient
-    model_name: GeminiModelName
-    auth: AuthProtocol
-    tools: _GeminiTools | None
-    tool_config: _GeminiToolConfig | None
-    url: str
-
-    def __init__(
-        self,
-        http_client: AsyncHTTPClient,
-        model_name: GeminiModelName,
-        auth: AuthProtocol,
-        url: str,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ):
-        tools = [_function_from_abstract_tool(t) for t in function_tools]
-        if result_tools:
-            tools += [_function_from_abstract_tool(t) for t in result_tools]
+    @property
+    def auth(self) -> AuthProtocol:
+        assert self._auth is not None, 'Auth not initialized'
+        return self._auth

-        if allow_text_result:
-            tool_config = None
-        else:
-            tool_config = _tool_config([t['name'] for t in tools])
-
-        self.http_client = http_client
-        self.model_name = model_name
-        self.auth = auth
-        self.tools = _GeminiTools(function_declarations=tools) if tools else None
-        self.tool_config = tool_config
-        self.url = url
+    @property
+    def url(self) -> str:
+        assert self._url is not None, 'URL not initialized'
+        return self._url

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
+        check_allow_model_requests()
         async with self._make_request(
-            messages, False, cast(GeminiModelSettings, model_settings or {})
+            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
         return self._process_response(response), _metadata_as_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
+        check_allow_model_requests()
+        async with self._make_request(
+            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
+        ) as http_response:
             yield await self._process_streamed_response(http_response)

+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
+        return _GeminiTools(function_declarations=tools) if tools else None
+
+    def _get_tool_config(
+        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
+    ) -> _GeminiToolConfig | None:
+        if model_request_parameters.allow_text_result:
+            return None
+        elif tools:
+            return _tool_config([t['name'] for t in tools['function_declarations']])
+        else:
+            return _tool_config([])
+
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
+        self,
+        messages: list[ModelMessage],
+        streamed: bool,
+        model_settings: GeminiModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[HTTPResponse]:
+        tools = self._get_tools(model_request_parameters)
+        tool_config = self._get_tool_config(model_request_parameters, tools)
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)

         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
             request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
-        if self.tools is not None:
-            request_data['tools'] = self.tools
-        if self.tool_config is not None:
-            request_data['tool_config'] = self.tool_config
+        if tools is not None:
+            request_data['tools'] = tools
+        if tool_config is not None:
+            request_data['tool_config'] = tool_config

         generation_config: _GeminiGenerationConfig = {}
         if model_settings:
@@ -216,6 +195,8 @@ class GeminiAgentModel(AgentModel):
             generation_config['presence_penalty'] = presence_penalty
         if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
             generation_config['frequency_penalty'] = frequency_penalty
+        if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
+            request_data['safety_settings'] = gemini_safety_settings
         if generation_config:
             request_data['generation_config'] = generation_config

@@ -244,8 +225,13 @@
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        if 'content' not in response['candidates'][0]:
+            if response['candidates'][0].get('finish_reason') == 'SAFETY':
+                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
+            else:
+                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, model_name=self.model_name)
+        return _process_response_from_parts(parts, model_name=self._model_name)

     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -261,14 +247,14 @@ class GeminiAgentModel(AgentModel):
         )
         if responses:
             last = responses[-1]
-            if last['candidates'] and last['candidates'][0]['content']['parts']:
+            if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                 start_response = last
                 break

         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)

     @classmethod
     def _message_to_gemini_content(
@@ -306,6 +292,23 @@ class GeminiAgentModel(AgentModel):
         return sys_prompt_parts, contents


+class AuthProtocol(Protocol):
+    """Abstract definition for Gemini authentication."""
+
+    async def headers(self) -> dict[str, str]: ...
+
+
+@dataclass
+class ApiKeyAuth:
+    """Authentication using an API key for the `X-Goog-Api-Key` header."""
+
+    api_key: str
+
+    async def headers(self) -> dict[str, str]:
+        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
+        return {'X-Goog-Api-Key': self.api_key}
+
+
 @dataclass
 class GeminiStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for the Gemini model."""
@@ -317,6 +320,8 @@ class GeminiStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for gemini_response in self._get_gemini_responses():
             candidate = gemini_response['candidates'][0]
+            if 'content' not in candidate:
+                raise UnexpectedModelBehavior('Streamed response has no content field')
             gemini_part: _GeminiPartUnion
             for gemini_part in candidate['content']['parts']:
                 if 'text' in gemini_part:
@@ -390,6 +395,7 @@ class _GeminiRequest(TypedDict):
     contents: list[_GeminiContent]
     tools: NotRequired[_GeminiTools]
     tool_config: NotRequired[_GeminiToolConfig]
+    safety_settings: NotRequired[list[GeminiSafetySettings]]
     # we don't implement `generationConfig`, instead we use a named tool for the response
     system_instruction: NotRequired[_GeminiTextContent]
     """
@@ -399,6 +405,38 @@ class _GeminiRequest(TypedDict):
     generation_config: NotRequired[_GeminiGenerationConfig]


+class GeminiSafetySettings(TypedDict):
+    """Safety settings options for Gemini model request.
+
+    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
+    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
+    """
+
+    category: Literal[
+        'HARM_CATEGORY_UNSPECIFIED',
+        'HARM_CATEGORY_HARASSMENT',
+        'HARM_CATEGORY_HATE_SPEECH',
+        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+        'HARM_CATEGORY_DANGEROUS_CONTENT',
+        'HARM_CATEGORY_CIVIC_INTEGRITY',
+    ]
+    """
+    Safety settings category.
+    """
+
+    threshold: Literal[
+        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+        'BLOCK_LOW_AND_ABOVE',
+        'BLOCK_MEDIUM_AND_ABOVE',
+        'BLOCK_ONLY_HIGH',
+        'BLOCK_NONE',
+        'OFF',
+    ]
+    """
+    Safety settings threshold.
+    """
+
+
 class _GeminiGenerationConfig(TypedDict, total=False):
     """Schema for an API request to the Gemini API.

@@ -575,8 +613,8 @@ class _GeminiResponse(TypedDict):
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

-    content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
+    content: NotRequired[_GeminiContent]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
@@ -624,6 +662,7 @@ class _GeminiSafetyRating(TypedDict):
         'HARM_CATEGORY_CIVIC_INTEGRITY',
     ]
     probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
+    blocked: NotRequired[bool]


 class _GeminiPromptFeedback(TypedDict):