pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,9 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import base64
4
- import re
4
+ import warnings
5
5
  from collections.abc import AsyncIterator, Sequence
6
6
  from contextlib import asynccontextmanager
7
- from copy import deepcopy
8
7
  from dataclasses import dataclass, field, replace
9
8
  from datetime import datetime
10
9
  from typing import Annotated, Any, Literal, Protocol, Union, cast
@@ -34,6 +33,7 @@ from ..messages import (
34
33
  ToolCallPart,
35
34
  ToolReturnPart,
36
35
  UserPromptPart,
36
+ VideoUrl,
37
37
  )
38
38
  from ..settings import ModelSettings
39
39
  from ..tools import ToolDefinition
@@ -45,6 +45,7 @@ from . import (
45
45
  check_allow_model_requests,
46
46
  get_user_agent,
47
47
  )
48
+ from ._json_schema import JsonSchema, WalkJsonSchema
48
49
 
49
50
  LatestGeminiModelNames = Literal[
50
51
  'gemini-1.5-flash',
@@ -58,6 +59,7 @@ LatestGeminiModelNames = Literal[
58
59
  'gemini-2.0-flash-lite-preview-02-05',
59
60
  'gemini-2.0-pro-exp-02-05',
60
61
  'gemini-2.5-pro-exp-03-25',
62
+ 'gemini-2.5-pro-preview-03-25',
61
63
  ]
62
64
  """Latest Gemini models."""
63
65
 
@@ -154,12 +156,12 @@ class GeminiModel(Model):
154
156
 
155
157
  def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
156
158
  def _customize_tool_def(t: ToolDefinition):
157
- return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).simplify())
159
+ return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
158
160
 
159
161
  return ModelRequestParameters(
160
162
  function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
161
- allow_text_result=model_request_parameters.allow_text_result,
162
- result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
163
+ allow_text_output=model_request_parameters.allow_text_output,
164
+ output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
163
165
  )
164
166
 
165
167
  @property
@@ -174,14 +176,14 @@ class GeminiModel(Model):
174
176
 
175
177
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
176
178
  tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
177
- if model_request_parameters.result_tools:
178
- tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
179
+ if model_request_parameters.output_tools:
180
+ tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
179
181
  return _GeminiTools(function_declarations=tools) if tools else None
180
182
 
181
183
  def _get_tool_config(
182
184
  self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
183
185
  ) -> _GeminiToolConfig | None:
184
- if model_request_parameters.allow_text_result:
186
+ if model_request_parameters.allow_text_output:
185
187
  return None
186
188
  elif tools:
187
189
  return _tool_config([t['name'] for t in tools['function_declarations']])
@@ -202,11 +204,11 @@ class GeminiModel(Model):
202
204
 
203
205
  request_data = _GeminiRequest(contents=contents)
204
206
  if sys_prompt_parts:
205
- request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
207
+ request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
206
208
  if tools is not None:
207
209
  request_data['tools'] = tools
208
210
  if tool_config is not None:
209
- request_data['tool_config'] = tool_config
211
+ request_data['toolConfig'] = tool_config
210
212
 
211
213
  generation_config: _GeminiGenerationConfig = {}
212
214
  if model_settings:
@@ -221,9 +223,9 @@ class GeminiModel(Model):
221
223
  if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
222
224
  generation_config['frequency_penalty'] = frequency_penalty
223
225
  if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
224
- request_data['safety_settings'] = gemini_safety_settings
226
+ request_data['safetySettings'] = gemini_safety_settings
225
227
  if generation_config:
226
- request_data['generation_config'] = generation_config
228
+ request_data['generationConfig'] = generation_config
227
229
 
228
230
  headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
229
231
  url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
@@ -277,9 +279,8 @@ class GeminiModel(Model):
277
279
 
278
280
  return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
279
281
 
280
- @classmethod
281
282
  async def _message_to_gemini_content(
282
- cls, messages: list[ModelMessage]
283
+ self, messages: list[ModelMessage]
283
284
  ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
284
285
  sys_prompt_parts: list[_GeminiTextPart] = []
285
286
  contents: list[_GeminiContent] = []
@@ -291,7 +292,7 @@ class GeminiModel(Model):
291
292
  if isinstance(part, SystemPromptPart):
292
293
  sys_prompt_parts.append(_GeminiTextPart(text=part.content))
293
294
  elif isinstance(part, UserPromptPart):
294
- message_parts.extend(await cls._map_user_prompt(part))
295
+ message_parts.extend(await self._map_user_prompt(part))
295
296
  elif isinstance(part, ToolReturnPart):
296
297
  message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
297
298
  elif isinstance(part, RetryPromptPart):
@@ -309,11 +310,11 @@ class GeminiModel(Model):
309
310
  contents.append(_content_model_response(m))
310
311
  else:
311
312
  assert_never(m)
312
-
313
+ if instructions := self._get_instructions(messages):
314
+ sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions))
313
315
  return sys_prompt_parts, contents
314
316
 
315
- @staticmethod
316
- async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
317
+ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]:
317
318
  if isinstance(part.content, str):
318
319
  return [{'text': part.content}]
319
320
  else:
@@ -335,6 +336,8 @@ class GeminiModel(Model):
335
336
  inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
336
337
  )
337
338
  content.append(inline_data)
339
+ elif isinstance(item, VideoUrl): # pragma: no cover
340
+ raise NotImplementedError('VideoUrl is not supported for Gemini.')
338
341
  else:
339
342
  assert_never(item)
340
343
  return content
@@ -448,17 +451,19 @@ class _GeminiRequest(TypedDict):
448
451
  See <https://ai.google.dev/api/generate-content#request-body> for API docs.
449
452
  """
450
453
 
454
+ # Note: Even though Google supposedly supports camelCase and snake_case, we've had users report misbehavior
455
+ # when using snake_case, which is why this typeddict now uses camelCase. And anyway, the plan is to replace this
456
+ # with an official Google SDK in the near future.
451
457
  contents: list[_GeminiContent]
452
458
  tools: NotRequired[_GeminiTools]
453
- tool_config: NotRequired[_GeminiToolConfig]
454
- safety_settings: NotRequired[list[GeminiSafetySettings]]
455
- # we don't implement `generationConfig`, instead we use a named tool for the response
456
- system_instruction: NotRequired[_GeminiTextContent]
459
+ toolConfig: NotRequired[_GeminiToolConfig]
460
+ safetySettings: NotRequired[list[GeminiSafetySettings]]
461
+ systemInstruction: NotRequired[_GeminiTextContent]
457
462
  """
458
463
  Developer generated system instructions, see
459
464
  <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
460
465
  """
461
- generation_config: NotRequired[_GeminiGenerationConfig]
466
+ generationConfig: NotRequired[_GeminiGenerationConfig]
462
467
 
463
468
 
464
469
  class GeminiSafetySettings(TypedDict):
@@ -757,7 +762,7 @@ _gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
757
762
  _gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
758
763
 
759
764
 
760
- class _GeminiJsonSchema:
765
+ class _GeminiJsonSchema(WalkJsonSchema):
761
766
  """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
762
767
 
763
768
  Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
@@ -768,72 +773,74 @@ class _GeminiJsonSchema:
768
773
  * gemini doesn't allow `$defs` — we need to inline the definitions where possible
769
774
  """
770
775
 
771
- def __init__(self, schema: _utils.ObjectJsonSchema):
772
- self.schema = deepcopy(schema)
773
- self.defs = self.schema.pop('$defs', {})
774
-
775
- def simplify(self) -> dict[str, Any]:
776
- self._simplify(self.schema, refs_stack=())
777
- return self.schema
776
+ def __init__(self, schema: JsonSchema):
777
+ super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
778
+
779
+ def transform(self, schema: JsonSchema) -> JsonSchema:
780
+ # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
781
+ additional_properties = schema.pop(
782
+ 'additionalProperties', None
783
+ ) # popped here; it is re-added to `original_schema` below so it's included in the warning
784
+ if additional_properties: # pragma: no cover
785
+ original_schema = {**schema, 'additionalProperties': additional_properties}
786
+ warnings.warn(
787
+ '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
788
+ f' Full schema: {self.schema}\n\n'
789
+ f'Source of additionalProperties within the full schema: {original_schema}\n\n'
790
+ 'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
791
+ "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
792
+ ' and we will fix this behavior.',
793
+ UserWarning,
794
+ )
778
795
 
779
- def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
780
796
  schema.pop('title', None)
781
797
  schema.pop('default', None)
782
798
  schema.pop('$schema', None)
799
+ if (const := schema.pop('const', None)) is not None: # pragma: no cover
800
+ # Gemini doesn't support const, but it does support enum with a single value
801
+ schema['enum'] = [const]
802
+ schema.pop('discriminator', None)
803
+ schema.pop('examples', None)
804
+
805
+ # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
806
+ # where we add notes about these properties to the field description?
783
807
  schema.pop('exclusiveMaximum', None)
784
808
  schema.pop('exclusiveMinimum', None)
785
- if ref := schema.pop('$ref', None):
786
- # noinspection PyTypeChecker
787
- key = re.sub(r'^#/\$defs/', '', ref)
788
- if key in refs_stack:
789
- raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
790
- refs_stack += (key,)
791
- schema_def = self.defs[key]
792
- self._simplify(schema_def, refs_stack)
793
- schema.update(schema_def)
794
- return
795
-
796
- if any_of := schema.get('anyOf'):
797
- for item_schema in any_of:
798
- self._simplify(item_schema, refs_stack)
799
- if len(any_of) == 2 and {'type': 'null'} in any_of:
800
- for item_schema in any_of:
801
- if item_schema != {'type': 'null'}:
802
- schema.clear()
803
- schema.update(item_schema)
804
- schema['nullable'] = True
805
- return
806
809
 
807
810
  type_ = schema.get('type')
811
+ if 'oneOf' in schema and 'type' not in schema: # pragma: no cover
812
+ # This gets hit when we have a discriminated union
813
+ # Gemini returns an API error in this case even though it says in its error message it shouldn't...
814
+ # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
815
+ schema['anyOf'] = schema.pop('oneOf')
808
816
 
809
- if type_ == 'object':
810
- self._object(schema, refs_stack)
811
- elif type_ == 'array':
812
- return self._array(schema, refs_stack)
813
- elif type_ == 'string' and (fmt := schema.pop('format', None)):
817
+ if type_ == 'string' and (fmt := schema.pop('format', None)):
814
818
  description = schema.get('description')
815
819
  if description:
816
820
  schema['description'] = f'{description} (format: {fmt})'
817
821
  else:
818
822
  schema['description'] = f'Format: {fmt}'
819
823
 
820
- def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
821
- ad_props = schema.pop('additionalProperties', None)
822
- if ad_props:
823
- raise UserError('Additional properties in JSON Schema are not supported by Gemini')
824
-
825
- if properties := schema.get('properties'): # pragma: no branch
826
- for value in properties.values():
827
- self._simplify(value, refs_stack)
828
-
829
- def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
830
- if prefix_items := schema.get('prefixItems'):
831
- # TODO I think this not is supported by Gemini, maybe we should raise an error?
832
- for prefix_item in prefix_items:
833
- self._simplify(prefix_item, refs_stack)
834
-
835
- if items_schema := schema.get('items'): # pragma: no branch
836
- self._simplify(items_schema, refs_stack)
824
+ if '$ref' in schema:
825
+ raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')
826
+
827
+ if 'prefixItems' in schema:
828
+ # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
829
+ prefix_items = schema.pop('prefixItems')
830
+ items = schema.get('items')
831
+ unique_items = [items] if items is not None else []
832
+ for item in prefix_items:
833
+ if item not in unique_items:
834
+ unique_items.append(item)
835
+ if len(unique_items) > 1: # pragma: no cover
836
+ schema['items'] = {'anyOf': unique_items}
837
+ elif len(unique_items) == 1:
838
+ schema['items'] = unique_items[0]
839
+ schema.setdefault('minItems', len(prefix_items))
840
+ if items is None:
841
+ schema.setdefault('maxItems', len(prefix_items))
842
+
843
+ return schema
837
844
 
838
845
 
839
846
  def _ensure_decodeable(content: bytearray) -> bytearray:
@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime, timezone
8
- from itertools import chain
9
8
  from typing import Literal, Union, cast, overload
10
9
 
11
10
  from typing_extensions import assert_never
@@ -193,12 +192,12 @@ class GroqModel(Model):
193
192
  # standalone function to make it easier to override
194
193
  if not tools:
195
194
  tool_choice: Literal['none', 'required', 'auto'] | None = None
196
- elif not model_request_parameters.allow_text_result:
195
+ elif not model_request_parameters.allow_text_output:
197
196
  tool_choice = 'required'
198
197
  else:
199
198
  tool_choice = 'auto'
200
199
 
201
- groq_messages = list(chain(*(self._map_message(m) for m in messages)))
200
+ groq_messages = self._map_messages(messages)
202
201
 
203
202
  try:
204
203
  return await self.client.chat.completions.create(
@@ -252,34 +251,39 @@ class GroqModel(Model):
252
251
 
253
252
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
254
253
  tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
255
- if model_request_parameters.result_tools:
256
- tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
254
+ if model_request_parameters.output_tools:
255
+ tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
257
256
  return tools
258
257
 
259
- def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
258
+ def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
260
259
  """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
261
- if isinstance(message, ModelRequest):
262
- yield from self._map_user_message(message)
263
- elif isinstance(message, ModelResponse):
264
- texts: list[str] = []
265
- tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
266
- for item in message.parts:
267
- if isinstance(item, TextPart):
268
- texts.append(item.content)
269
- elif isinstance(item, ToolCallPart):
270
- tool_calls.append(self._map_tool_call(item))
271
- else:
272
- assert_never(item)
273
- message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
274
- if texts:
275
- # Note: model responses from this model should only have one text item, so the following
276
- # shouldn't merge multiple texts into one unless you switch models between runs:
277
- message_param['content'] = '\n\n'.join(texts)
278
- if tool_calls:
279
- message_param['tool_calls'] = tool_calls
280
- yield message_param
281
- else:
282
- assert_never(message)
260
+ groq_messages: list[chat.ChatCompletionMessageParam] = []
261
+ for message in messages:
262
+ if isinstance(message, ModelRequest):
263
+ groq_messages.extend(self._map_user_message(message))
264
+ elif isinstance(message, ModelResponse):
265
+ texts: list[str] = []
266
+ tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
267
+ for item in message.parts:
268
+ if isinstance(item, TextPart):
269
+ texts.append(item.content)
270
+ elif isinstance(item, ToolCallPart):
271
+ tool_calls.append(self._map_tool_call(item))
272
+ else:
273
+ assert_never(item)
274
+ message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
275
+ if texts:
276
+ # Note: model responses from this model should only have one text item, so the following
277
+ # shouldn't merge multiple texts into one unless you switch models between runs:
278
+ message_param['content'] = '\n\n'.join(texts)
279
+ if tool_calls:
280
+ message_param['tool_calls'] = tool_calls
281
+ groq_messages.append(message_param)
282
+ else:
283
+ assert_never(message)
284
+ if instructions := self._get_instructions(messages):
285
+ groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions))
286
+ return groq_messages
283
287
 
284
288
  @staticmethod
285
289
  def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
@@ -260,7 +260,7 @@ class InstrumentedModel(WrapperModel):
260
260
 
261
261
  @staticmethod
262
262
  def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
263
- result: list[Event] = []
263
+ events: list[Event] = []
264
264
  for message_index, message in enumerate(messages):
265
265
  message_events: list[Event] = []
266
266
  if isinstance(message, ModelRequest):
@@ -274,10 +274,10 @@ class InstrumentedModel(WrapperModel):
274
274
  'gen_ai.message.index': message_index,
275
275
  **(event.attributes or {}),
276
276
  }
277
- result.extend(message_events)
278
- for event in result:
277
+ events.extend(message_events)
278
+ for event in events:
279
279
  event.body = InstrumentedModel.serialize_any(event.body)
280
- return result
280
+ return events
281
281
 
282
282
  @staticmethod
283
283
  def serialize_any(value: Any) -> str: