pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -31,15 +31,15 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
 )
 
-GeminiModelName = Literal[
+LatestGeminiModelNames = Literal[
     'gemini-1.5-flash',
     'gemini-1.5-flash-8b',
     'gemini-1.5-pro',
@@ -47,9 +47,16 @@ GeminiModelName = Literal[
     'gemini-2.0-flash-exp',
     'gemini-2.0-flash-thinking-exp-01-21',
     'gemini-exp-1206',
+    'gemini-2.0-flash',
+    'gemini-2.0-flash-lite-preview-02-05',
 ]
-"""Named Gemini models.
+"""Latest Gemini models."""
 
+GeminiModelName = Union[str, LatestGeminiModelNames]
+"""Possible Gemini model names.
+
+Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """
 
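Note on the widened name type: because `GeminiModelName` is now `Union[str, LatestGeminiModelNames]`, date-stamped or preview model names that are missing from the `Literal` list still type-check. A minimal sketch (the specific model string is illustrative, not taken from this diff):

    from pydantic_ai.models.gemini import GeminiModel

    # Any string is accepted; the Literal above only enumerates the latest names.
    model = GeminiModel('gemini-1.5-pro-002', api_key='your-api-key')
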
@@ -57,7 +64,7 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 class GeminiModelSettings(ModelSettings):
     """Settings used for a Gemini model request."""
 
-    # This class is a placeholder for any future gemini-specific settings
+    gemini_safety_settings: list[GeminiSafetySettings]
 
 
 @dataclass(init=False)
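`GeminiModelSettings` now carries a concrete option instead of being a placeholder. A minimal usage sketch, assuming the settings are passed through an agent run via `model_settings` as the agents docs linked from the `GeminiSafetySettings` docstring further down describe (the model string, prompt and threshold choices are illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.gemini import GeminiModelSettings

    agent = Agent('google-gla:gemini-1.5-flash')
    result = agent.run_sync(
        'Tell me a joke.',
        model_settings=GeminiModelSettings(
            gemini_safety_settings=[
                {'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_ONLY_HIGH'},
            ]
        ),
    )
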
@@ -70,10 +77,12 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: GeminiModelName
-    auth: AuthProtocol
-    http_client: AsyncHTTPClient
-    url: str
+    http_client: AsyncHTTPClient = field(repr=False)
+
+    _model_name: GeminiModelName = field(repr=False)
+    _auth: AuthProtocol | None = field(repr=False)
+    _url: str | None = field(repr=False)
+    _system: str | None = field(default='google-gla', repr=False)
 
     def __init__(
         self,
@@ -94,121 +103,97 @@ class GeminiModel(Model):
                docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
                `model` is substituted with the model name, and `function` is added to the end of the URL.
        """
-        self.model_name = model_name
+        self._model_name = model_name
         if api_key is None:
             if env_api_key := os.getenv('GEMINI_API_KEY'):
                 api_key = env_api_key
             else:
                 raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-        self.auth = ApiKeyAuth(api_key)
         self.http_client = http_client or cached_async_http_client()
-        self.url = url_template.format(model=model_name)
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> GeminiAgentModel:
-        check_allow_model_requests()
-        return GeminiAgentModel(
-            http_client=self.http_client,
-            model_name=self.model_name,
-            auth=self.auth,
-            url=self.url,
-            function_tools=function_tools,
-            allow_text_result=allow_text_result,
-            result_tools=result_tools,
-        )
-
-    def name(self) -> str:
-        return f'google-gla:{self.model_name}'
-
-
-class AuthProtocol(Protocol):
-    """Abstract definition for Gemini authentication."""
-
-    async def headers(self) -> dict[str, str]: ...
-
-
-@dataclass
-class ApiKeyAuth:
-    """Authentication using an API key for the `X-Goog-Api-Key` header."""
-
-    api_key: str
+        self._auth = ApiKeyAuth(api_key)
+        self._url = url_template.format(model=model_name)
 
-    async def headers(self) -> dict[str, str]:
-        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
-        return {'X-Goog-Api-Key': self.api_key}
-
-
-@dataclass(init=False)
-class GeminiAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Gemini models."""
-
-    http_client: AsyncHTTPClient
-    model_name: GeminiModelName
-    auth: AuthProtocol
-    tools: _GeminiTools | None
-    tool_config: _GeminiToolConfig | None
-    url: str
-
-    def __init__(
-        self,
-        http_client: AsyncHTTPClient,
-        model_name: GeminiModelName,
-        auth: AuthProtocol,
-        url: str,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ):
-        tools = [_function_from_abstract_tool(t) for t in function_tools]
-        if result_tools:
-            tools += [_function_from_abstract_tool(t) for t in result_tools]
+    @property
+    def auth(self) -> AuthProtocol:
+        assert self._auth is not None, 'Auth not initialized'
+        return self._auth
 
-        if allow_text_result:
-            tool_config = None
-        else:
-            tool_config = _tool_config([t['name'] for t in tools])
-
-        self.http_client = http_client
-        self.model_name = model_name
-        self.auth = auth
-        self.tools = _GeminiTools(function_declarations=tools) if tools else None
-        self.tool_config = tool_config
-        self.url = url
+    @property
+    def url(self) -> str:
+        assert self._url is not None, 'URL not initialized'
+        return self._url
 
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, usage.Usage]:
+        check_allow_model_requests()
         async with self._make_request(
-            messages, False, cast(GeminiModelSettings, model_settings or {})
+            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
+        check_allow_model_requests()
+        async with self._make_request(
+            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
+        ) as http_response:
             yield await self._process_streamed_response(http_response)
 
+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
+        return _GeminiTools(function_declarations=tools) if tools else None
+
+    def _get_tool_config(
+        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
+    ) -> _GeminiToolConfig | None:
+        if model_request_parameters.allow_text_result:
+            return None
+        elif tools:
+            return _tool_config([t['name'] for t in tools['function_declarations']])
+        else:
+            return _tool_config([])
+
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
+        self,
+        messages: list[ModelMessage],
+        streamed: bool,
+        model_settings: GeminiModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[HTTPResponse]:
+        tools = self._get_tools(model_request_parameters)
+        tool_config = self._get_tool_config(model_request_parameters, tools)
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
             request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
-        if self.tools is not None:
-            request_data['tools'] = self.tools
-        if self.tool_config is not None:
-            request_data['tool_config'] = self.tool_config
+        if tools is not None:
+            request_data['tools'] = tools
+        if tool_config is not None:
+            request_data['tool_config'] = tool_config
 
         generation_config: _GeminiGenerationConfig = {}
         if model_settings:
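The tool setup that previously lived on `GeminiAgentModel` now arrives per call through `model_request_parameters`. A rough sketch of driving the model directly, assuming `ModelRequestParameters` can be constructed with the `function_tools`, `allow_text_result` and `result_tools` fields that `_get_tools`/`_get_tool_config` read above (the exact constructor shape is an assumption, not shown in this diff):

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.gemini import GeminiModel

    async def ask(model: GeminiModel, prompt: str):
        # In normal use the agent supplies the tool lists; they are empty for a bare request.
        params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
        messages = [ModelRequest(parts=[UserPromptPart(content=prompt)])]
        return await model.request(messages, None, params)
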
@@ -222,6 +207,8 @@ class GeminiAgentModel(AgentModel):
                 generation_config['presence_penalty'] = presence_penalty
             if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                 generation_config['frequency_penalty'] = frequency_penalty
+            if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
+                request_data['safety_settings'] = gemini_safety_settings
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -250,8 +237,13 @@ class GeminiAgentModel(AgentModel):
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        if 'content' not in response['candidates'][0]:
+            if response['candidates'][0].get('finish_reason') == 'SAFETY':
+                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
+            else:
+                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts, model_name=self.model_name)
+        return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -267,14 +259,14 @@ class GeminiAgentModel(AgentModel):
            )
            if responses:
                last = responses[-1]
-                if last['candidates'] and last['candidates'][0]['content']['parts']:
+                if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                    start_response = last
                    break
 
        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
    @classmethod
    def _message_to_gemini_content(
@@ -312,10 +304,28 @@ class GeminiAgentModel(AgentModel):
         return sys_prompt_parts, contents
 
 
+class AuthProtocol(Protocol):
+    """Abstract definition for Gemini authentication."""
+
+    async def headers(self) -> dict[str, str]: ...
+
+
+@dataclass
+class ApiKeyAuth:
+    """Authentication using an API key for the `X-Goog-Api-Key` header."""
+
+    api_key: str
+
+    async def headers(self) -> dict[str, str]:
+        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
+        return {'X-Goog-Api-Key': self.api_key}
+
+
 @dataclass
 class GeminiStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for the Gemini model."""
 
+    _model_name: GeminiModelName
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -323,6 +333,8 @@ class GeminiStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for gemini_response in self._get_gemini_responses():
             candidate = gemini_response['candidates'][0]
+            if 'content' not in candidate:
+                raise UnexpectedModelBehavior('Streamed response has no content field')
             gemini_part: _GeminiPartUnion
             for gemini_part in candidate['content']['parts']:
                 if 'text' in gemini_part:
@@ -377,7 +389,14 @@ class GeminiStreamedResponse(StreamedResponse):
                self._usage += _metadata_as_usage(r)
                yield r
 
+    @property
+    def model_name(self) -> GeminiModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
        return self._timestamp
 
 
@@ -396,6 +415,7 @@ class _GeminiRequest(TypedDict):
     contents: list[_GeminiContent]
     tools: NotRequired[_GeminiTools]
     tool_config: NotRequired[_GeminiToolConfig]
+    safety_settings: NotRequired[list[GeminiSafetySettings]]
     # we don't implement `generationConfig`, instead we use a named tool for the response
     system_instruction: NotRequired[_GeminiTextContent]
     """
@@ -405,6 +425,38 @@ class _GeminiRequest(TypedDict):
     generation_config: NotRequired[_GeminiGenerationConfig]
 
 
+class GeminiSafetySettings(TypedDict):
+    """Safety settings options for Gemini model request.
+
+    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
+    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
+    """
+
+    category: Literal[
+        'HARM_CATEGORY_UNSPECIFIED',
+        'HARM_CATEGORY_HARASSMENT',
+        'HARM_CATEGORY_HATE_SPEECH',
+        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+        'HARM_CATEGORY_DANGEROUS_CONTENT',
+        'HARM_CATEGORY_CIVIC_INTEGRITY',
+    ]
+    """
+    Safety settings category.
+    """
+
+    threshold: Literal[
+        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
+        'BLOCK_LOW_AND_ABOVE',
+        'BLOCK_MEDIUM_AND_ABOVE',
+        'BLOCK_ONLY_HIGH',
+        'BLOCK_NONE',
+        'OFF',
+    ]
+    """
+    Safety settings threshold.
+    """
+
+
 class _GeminiGenerationConfig(TypedDict, total=False):
     """Schema for an API request to the Gemini API.
 
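Because `GeminiSafetySettings` is a plain `TypedDict`, the settings can be written as dict literals and checked against the `category`/`threshold` literals above; the pairs below are only illustrative:

    from pydantic_ai.models.gemini import GeminiSafetySettings

    safety_settings: list[GeminiSafetySettings] = [
        {'category': 'HARM_CATEGORY_HATE_SPEECH', 'threshold': 'BLOCK_MEDIUM_AND_ABOVE'},
        {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_ONLY_HIGH'},
    ]
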
@@ -576,13 +628,14 @@ class _GeminiResponse(TypedDict):
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
     usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
+    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
 
 
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
-    content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
+    content: NotRequired[_GeminiContent]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
@@ -630,6 +683,7 @@ class _GeminiSafetyRating(TypedDict):
         'HARM_CATEGORY_CIVIC_INTEGRITY',
     ]
     probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
+    blocked: NotRequired[bool]
 
 
 class _GeminiPromptFeedback(TypedDict):