pydantic_ai_slim-0.0.22-py3-none-any.whl → pydantic_ai_slim-0.0.23-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry; it is provided for informational purposes only.


pydantic_ai/models/function.py

@@ -4,7 +4,7 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union
@@ -27,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, Model, StreamedResponse
+from . import Model, ModelRequestParameters, StreamedResponse
 
 
 @dataclass(init=False)
@@ -40,6 +40,9 @@ class FunctionModel(Model):
     function: FunctionDef | None = None
     stream_function: StreamFunctionDef | None = None
 
+    _model_name: str = field(repr=False)
+    _system: str | None = field(default=None, repr=False)
+
     @overload
     def __init__(self, function: FunctionDef) -> None: ...
 
@@ -63,23 +66,60 @@ class FunctionModel(Model):
         self.function = function
         self.stream_function = stream_function
 
-    async def agent_model(
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        self._model_name = f'function:{function_name}:{stream_function_name}'
+
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        return FunctionAgentModel(
-            self.function,
-            self.stream_function,
-            AgentInfo(function_tools, allow_text_result, result_tools, None),
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
         )
 
-    def name(self) -> str:
-        function_name = self.function.__name__ if self.function is not None else ''
-        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        return f'function:{function_name}:{stream_function_name}'
+        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+
+        if inspect.iscoroutinefunction(self.function):
+            response = await self.function(messages, agent_info)
+        else:
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
+        response.model_name = f'function:{self.function.__name__}'
+        # TODO is `messages` right here? Should it just be new messages?
+        return response, _estimate_usage(chain(messages, [response]))
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
+        )
+
+        assert (
+            self.stream_function is not None
+        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+
+        response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
+
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
+
+        yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
 
 
 @dataclass(frozen=True)
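
For orientation (this example is not part of the diff): a `FunctionDef` still receives the message history and an `AgentInfo`, so driving the new `request` path through an agent looks like the sketch below, assuming the 0.0.x `pydantic_ai.messages` API (`ModelResponse`, `TextPart`) and `Agent.run_sync`:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def echo_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A FunctionDef: takes the message history plus AgentInfo and
    # returns a ModelResponse (or an awaitable of one).
    return ModelResponse(parts=[TextPart(content=f'saw {len(messages)} message(s)')])


agent = Agent(FunctionModel(echo_model))
result = agent.run_sync('hello')
print(result.data)
```

Note that the model name is now computed once in `__init__` (here `'function:echo_model:'`) rather than on demand via the removed `name()` method.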
@@ -119,9 +159,11 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
+# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
 FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
+# TODO: Change signature as indicated above
 StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.
 
@@ -132,50 +174,6 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
 """
 
 
-@dataclass
-class FunctionAgentModel(AgentModel):
-    """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    function: FunctionDef | None
-    stream_function: StreamFunctionDef | None
-    agent_info: AgentInfo
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        agent_info = replace(self.agent_info, model_settings=model_settings)
-
-        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
-        model_name = f'function:{self.function.__name__}'
-
-        if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, agent_info)
-        else:
-            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
-            assert isinstance(response_, ModelResponse), response_
-            response = response_
-        response.model_name = model_name
-        # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_usage(chain(messages, [response]))
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        model_name = f'function:{self.stream_function.__name__}'
-
-        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
-
-        first = await response_stream.peek()
-        if isinstance(first, _utils.Unset):
-            raise ValueError('Stream function must return at least one item')
-
-        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
-
-
 @dataclass
 class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
pydantic_ai/models/gemini.py

@@ -31,15 +31,15 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
 )
 
-GeminiModelName = Literal[
+LatestGeminiModelNames = Literal[
     'gemini-1.5-flash',
     'gemini-1.5-flash-8b',
     'gemini-1.5-pro',
@@ -48,8 +48,13 @@ GeminiModelName = Literal[
     'gemini-2.0-flash-thinking-exp-01-21',
     'gemini-exp-1206',
 ]
-"""Named Gemini models.
+"""Latest Gemini models."""
 
+GeminiModelName = Union[str, LatestGeminiModelNames]
+"""Possible Gemini model names.
+
+Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """
 
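
The practical effect of the widened alias: the `Literal` keeps IDE completion for the latest names, while any other string (e.g. a date-stamped variant, illustrative below) now type-checks as well:

```python
from pydantic_ai.models.gemini import GeminiModel

# Previously only the Literal members satisfied GeminiModelName;
# with the Union[str, ...] alias any model string is accepted.
model = GeminiModel('gemini-1.5-pro-002', api_key='your-api-key')
```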
@@ -57,7 +62,7 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 class GeminiModelSettings(ModelSettings):
     """Settings used for a Gemini model request."""
 
-    # This class is a placeholder for any future gemini-specific settings
+    gemini_safety_settings: list[GeminiSafetySettings]
 
 
 @dataclass(init=False)
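
`GeminiModelSettings` is therefore no longer an empty placeholder. A minimal sketch of passing the new setting on a run (assuming the usual `Agent` flow and `GEMINI_API_KEY` set in the environment):

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings

agent = Agent(GeminiModel('gemini-1.5-flash'))  # api_key falls back to GEMINI_API_KEY
result = agent.run_sync(
    'Tell me a joke.',
    model_settings=GeminiModelSettings(
        gemini_safety_settings=[
            {'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
        ]
    ),
)
print(result.data)
```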
@@ -70,10 +75,12 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: GeminiModelName
-    auth: AuthProtocol
-    http_client: AsyncHTTPClient
-    url: str
+    http_client: AsyncHTTPClient = field(repr=False)
+
+    _model_name: GeminiModelName = field(repr=False)
+    _auth: AuthProtocol | None = field(repr=False)
+    _url: str | None = field(repr=False)
+    _system: str | None = field(default='google-gla', repr=False)
 
     def __init__(
         self,
@@ -94,121 +101,87 @@ class GeminiModel(Model):
             docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
             `model` is substituted with the model name, and `function` is added to the end of the URL.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if api_key is None:
             if env_api_key := os.getenv('GEMINI_API_KEY'):
                 api_key = env_api_key
             else:
                 raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-        self.auth = ApiKeyAuth(api_key)
         self.http_client = http_client or cached_async_http_client()
-        self.url = url_template.format(model=model_name)
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> GeminiAgentModel:
-        check_allow_model_requests()
-        return GeminiAgentModel(
-            http_client=self.http_client,
-            model_name=self.model_name,
-            auth=self.auth,
-            url=self.url,
-            function_tools=function_tools,
-            allow_text_result=allow_text_result,
-            result_tools=result_tools,
-        )
-
-    def name(self) -> str:
-        return f'google-gla:{self.model_name}'
-
-
-class AuthProtocol(Protocol):
-    """Abstract definition for Gemini authentication."""
-
-    async def headers(self) -> dict[str, str]: ...
-
-
-@dataclass
-class ApiKeyAuth:
-    """Authentication using an API key for the `X-Goog-Api-Key` header."""
-
-    api_key: str
+        self._auth = ApiKeyAuth(api_key)
+        self._url = url_template.format(model=model_name)
 
-    async def headers(self) -> dict[str, str]:
-        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
-        return {'X-Goog-Api-Key': self.api_key}
-
-
-@dataclass(init=False)
-class GeminiAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Gemini models."""
-
-    http_client: AsyncHTTPClient
-    model_name: GeminiModelName
-    auth: AuthProtocol
-    tools: _GeminiTools | None
-    tool_config: _GeminiToolConfig | None
-    url: str
-
-    def __init__(
-        self,
-        http_client: AsyncHTTPClient,
-        model_name: GeminiModelName,
-        auth: AuthProtocol,
-        url: str,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ):
-        tools = [_function_from_abstract_tool(t) for t in function_tools]
-        if result_tools:
-            tools += [_function_from_abstract_tool(t) for t in result_tools]
+    @property
+    def auth(self) -> AuthProtocol:
+        assert self._auth is not None, 'Auth not initialized'
+        return self._auth
 
-        if allow_text_result:
-            tool_config = None
-        else:
-            tool_config = _tool_config([t['name'] for t in tools])
-
-        self.http_client = http_client
-        self.model_name = model_name
-        self.auth = auth
-        self.tools = _GeminiTools(function_declarations=tools) if tools else None
-        self.tool_config = tool_config
-        self.url = url
+    @property
+    def url(self) -> str:
+        assert self._url is not None, 'URL not initialized'
+        return self._url
 
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
+        check_allow_model_requests()
         async with self._make_request(
-            messages, False, cast(GeminiModelSettings, model_settings or {})
+            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
+        check_allow_model_requests()
+        async with self._make_request(
+            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
+        ) as http_response:
             yield await self._process_streamed_response(http_response)
 
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
+        return _GeminiTools(function_declarations=tools) if tools else None
+
+    def _get_tool_config(
+        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
+    ) -> _GeminiToolConfig | None:
+        if model_request_parameters.allow_text_result:
+            return None
+        elif tools:
+            return _tool_config([t['name'] for t in tools['function_declarations']])
+        else:
+            return _tool_config([])
+
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
+        self,
+        messages: list[ModelMessage],
+        streamed: bool,
+        model_settings: GeminiModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[HTTPResponse]:
+        tools = self._get_tools(model_request_parameters)
+        tool_config = self._get_tool_config(model_request_parameters, tools)
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
             request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
-        if self.tools is not None:
-            request_data['tools'] = self.tools
-        if self.tool_config is not None:
-            request_data['tool_config'] = self.tool_config
+        if tools is not None:
+            request_data['tools'] = tools
+        if tool_config is not None:
+            request_data['tool_config'] = tool_config
 
         generation_config: _GeminiGenerationConfig = {}
         if model_settings:
@@ -222,6 +195,8 @@ class GeminiAgentModel(AgentModel):
                 generation_config['presence_penalty'] = presence_penalty
             if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                 generation_config['frequency_penalty'] = frequency_penalty
+            if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
+                request_data['safety_settings'] = gemini_safety_settings
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -250,8 +225,13 @@ class GeminiAgentModel(AgentModel):
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        if 'content' not in response['candidates'][0]:
+            if response['candidates'][0].get('finish_reason') == 'SAFETY':
+                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
+            else:
+                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, model_name=self.model_name)
+        return _process_response_from_parts(parts, model_name=self._model_name)
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -267,14 +247,14 @@ class GeminiAgentModel(AgentModel):
         )
         if responses:
             last = responses[-1]
-            if last['candidates'] and last['candidates'][0]['content']['parts']:
+            if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                 start_response = last
                 break
 
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -312,6 +292,23 @@ class GeminiAgentModel(AgentModel):
         return sys_prompt_parts, contents
 
 
+class AuthProtocol(Protocol):
+    """Abstract definition for Gemini authentication."""
+
+    async def headers(self) -> dict[str, str]: ...
+
+
+@dataclass
+class ApiKeyAuth:
+    """Authentication using an API key for the `X-Goog-Api-Key` header."""
+
+    api_key: str
+
+    async def headers(self) -> dict[str, str]:
+        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
+        return {'X-Goog-Api-Key': self.api_key}
+
+
 @dataclass
 class GeminiStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for the Gemini model."""
@@ -323,6 +320,8 @@ class GeminiStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for gemini_response in self._get_gemini_responses():
             candidate = gemini_response['candidates'][0]
+            if 'content' not in candidate:
+                raise UnexpectedModelBehavior('Streamed response has no content field')
             gemini_part: _GeminiPartUnion
             for gemini_part in candidate['content']['parts']:
                 if 'text' in gemini_part:
@@ -396,6 +395,7 @@ class _GeminiRequest(TypedDict):
     contents: list[_GeminiContent]
     tools: NotRequired[_GeminiTools]
     tool_config: NotRequired[_GeminiToolConfig]
+    safety_settings: NotRequired[list[GeminiSafetySettings]]
     # we don't implement `generationConfig`, instead we use a named tool for the response
     system_instruction: NotRequired[_GeminiTextContent]
     """
@@ -405,6 +405,38 @@ class _GeminiRequest(TypedDict):
     generation_config: NotRequired[_GeminiGenerationConfig]
 
 
+class GeminiSafetySettings(TypedDict):
+    """Safety settings options for Gemini model request.
+
+    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
+    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
+    """
+
+    category: Literal[
+        'HARM_CATEGORY_UNSPECIFIED',
+        'HARM_CATEGORY_HARASSMENT',
+        'HARM_CATEGORY_HATE_SPEECH',
+        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+        'HARM_CATEGORY_DANGEROUS_CONTENT',
+        'HARM_CATEGORY_CIVIC_INTEGRITY',
+    ]
+    """
+    Safety settings category.
+    """
+
+    threshold: Literal[
+        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+        'BLOCK_LOW_AND_ABOVE',
+        'BLOCK_MEDIUM_AND_ABOVE',
+        'BLOCK_ONLY_HIGH',
+        'BLOCK_NONE',
+        'OFF',
+    ]
+    """
+    Safety settings threshold.
+    """
+
+
 class _GeminiGenerationConfig(TypedDict, total=False):
     """Schema for an API request to the Gemini API.
 
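
Since `GeminiSafetySettings` is a plain `TypedDict`, entries are ordinary dicts that a type checker can validate against the two `Literal` fields above:

```python
from pydantic_ai.models.gemini import GeminiSafetySettings

block_dangerous: GeminiSafetySettings = {
    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
    'threshold': 'BLOCK_ONLY_HIGH',
}
```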
@@ -581,8 +613,8 @@ class _GeminiResponse(TypedDict):
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
-    content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
+    content: NotRequired[_GeminiContent]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
@@ -630,6 +662,7 @@ class _GeminiSafetyRating(TypedDict):
         'HARM_CATEGORY_CIVIC_INTEGRITY',
     ]
     probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
+    blocked: NotRequired[bool]
 
 
 class _GeminiPromptFeedback(TypedDict):