pydantic-ai-slim 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_ai/models/gemini.py:

@@ -1,10 +1,8 @@
 from __future__ import annotations as _annotations

 import base64
-import re
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from copy import deepcopy
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union, cast
@@ -34,6 +32,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
+    VideoUrl,
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -45,6 +44,7 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
+from ._json_schema import JsonSchema, WalkJsonSchema

 LatestGeminiModelNames = Literal[
     'gemini-1.5-flash',
@@ -58,6 +58,7 @@ LatestGeminiModelNames = Literal[
     'gemini-2.0-flash-lite-preview-02-05',
     'gemini-2.0-pro-exp-02-05',
     'gemini-2.5-pro-exp-03-25',
+    'gemini-2.5-pro-preview-03-25',
 ]
 """Latest Gemini models."""

@@ -154,12 +155,12 @@ class GeminiModel(Model):

     def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
         def _customize_tool_def(t: ToolDefinition):
-            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).simplify())
+            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())

         return ModelRequestParameters(
             function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-            allow_text_result=model_request_parameters.allow_text_result,
-            result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
+            allow_text_output=model_request_parameters.allow_text_output,
+            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
         )

     @property
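Here `walk()` comes from the new shared `WalkJsonSchema` base class (see the `_GeminiJsonSchema` changes further down): the base class owns the recursion and each subclass only implements a per-node `transform`. A minimal sketch of that walker pattern, as a hypothetical stand-in for the private `pydantic_ai.models._json_schema` module rather than its actual implementation:

# Minimal sketch of the walker pattern behind `walk()`/`transform()`; a hypothetical
# stand-in for pydantic_ai.models._json_schema, for illustration only.
from __future__ import annotations

from copy import deepcopy
from typing import Any

JsonSchema = dict[str, Any]


class SimpleSchemaWalker:
    """Visits every sub-schema depth-first and applies `transform` to each node."""

    def __init__(self, schema: JsonSchema):
        self.schema = deepcopy(schema)

    def walk(self) -> JsonSchema:
        return self._handle(self.schema)

    def _handle(self, schema: JsonSchema) -> JsonSchema:
        # Recurse into the common nesting points before transforming the node itself.
        if isinstance(schema.get('properties'), dict):
            schema['properties'] = {k: self._handle(v) for k, v in schema['properties'].items()}
        if isinstance(schema.get('items'), dict):
            schema['items'] = self._handle(schema['items'])
        for key in ('anyOf', 'oneOf'):
            if key in schema:
                schema[key] = [self._handle(s) for s in schema[key]]
        return self.transform(schema)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        return schema  # subclasses override, e.g. to drop `title` as _GeminiJsonSchema does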
@@ -174,14 +175,14 @@ class GeminiModel(Model):

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
         return _GeminiTools(function_declarations=tools) if tools else None

     def _get_tool_config(
         self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
     ) -> _GeminiToolConfig | None:
-        if model_request_parameters.allow_text_result:
+        if model_request_parameters.allow_text_output:
             return None
         elif tools:
             return _tool_config([t['name'] for t in tools['function_declarations']])
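When `allow_text_output` is false, the tool config forces the model to call one of the declared functions instead of replying with plain text. Per Google's function-calling docs the payload looks roughly like the sketch below; the key casing here follows the REST documentation, and the library's internal `_tool_config` helper may spell it differently:

# Hedged sketch of a Gemini tool config that forces a function call; see
# <https://ai.google.dev/gemini-api/docs/function-calling>. Key spelling per the
# REST docs, which may differ from the internal `_tool_config` helper.
def force_tool_call(function_names: list[str]) -> dict:
    return {
        'functionCallingConfig': {
            'mode': 'ANY',  # the model must call one of the allowed functions
            'allowedFunctionNames': function_names,
        }
    }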
@@ -202,11 +203,11 @@ class GeminiModel(Model):

         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
-            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
+            request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
         if tools is not None:
             request_data['tools'] = tools
         if tool_config is not None:
-            request_data['tool_config'] = tool_config
+            request_data['toolConfig'] = tool_config

         generation_config: _GeminiGenerationConfig = {}
         if model_settings:
@@ -221,9 +222,9 @@ class GeminiModel(Model):
             if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                 generation_config['frequency_penalty'] = frequency_penalty
             if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
-                request_data['safety_settings'] = gemini_safety_settings
+                request_data['safetySettings'] = gemini_safety_settings
         if generation_config:
-            request_data['generation_config'] = generation_config
+            request_data['generationConfig'] = generation_config

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
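`gemini_safety_settings` is the Gemini-specific entry in the model settings consumed above. A plausible usage sketch, with category and threshold strings taken from Google's SafetySetting docs (<https://ai.google.dev/api/generate-content#safetysetting>):

# Plausible settings dict feeding the branch above; the values are examples,
# not defaults used by the library.
gemini_settings = {
    'temperature': 0.0,
    'gemini_safety_settings': [
        {'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_ONLY_HIGH'},
    ],
}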
@@ -277,9 +278,8 @@ class GeminiModel(Model):

         return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)

-    @classmethod
     async def _message_to_gemini_content(
-        cls, messages: list[ModelMessage]
+        self, messages: list[ModelMessage]
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
         contents: list[_GeminiContent] = []
@@ -291,7 +291,7 @@ class GeminiModel(Model):
                 if isinstance(part, SystemPromptPart):
                     sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                 elif isinstance(part, UserPromptPart):
-                    message_parts.extend(await cls._map_user_prompt(part))
+                    message_parts.extend(await self._map_user_prompt(part))
                 elif isinstance(part, ToolReturnPart):
                     message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                 elif isinstance(part, RetryPromptPart):
@@ -309,11 +309,11 @@ class GeminiModel(Model):
                 contents.append(_content_model_response(m))
             else:
                 assert_never(m)
-
+        if instructions := self._get_instructions(messages):
+            sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions))
         return sys_prompt_parts, contents

-    @staticmethod
-    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
+    async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]:
         if isinstance(part.content, str):
             return [{'text': part.content}]
         else:
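The new `_get_instructions` hook means run-level instructions are emitted ahead of any `SystemPromptPart`s collected from the history. A small sketch of the resulting ordering:

# Sketch of the ordering produced above: instructions, when present, become the
# first system part. The literal strings are invented for the example.
sys_prompt_parts = [{'text': 'You are a helpful assistant.'}]  # from a SystemPromptPart
instructions = 'Answer in one sentence.'  # what self._get_instructions(messages) might return
if instructions:
    sys_prompt_parts.insert(0, {'text': instructions})
assert sys_prompt_parts[0] == {'text': 'Answer in one sentence.'}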
@@ -335,6 +335,8 @@ class GeminiModel(Model):
                         inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
                     )
                     content.append(inline_data)
+                elif isinstance(item, VideoUrl):  # pragma: no cover
+                    raise NotImplementedError('VideoUrl is not supported for Gemini.')
                 else:
                     assert_never(item)
             return content
@@ -448,17 +450,19 @@ class _GeminiRequest(TypedDict):
     See <https://ai.google.dev/api/generate-content#request-body> for API docs.
     """

+    # Note: Even though Google supposedly supports camelCase and snake_case, we've had user report misbehavior
+    # when using snake_case, which is why this typeddict now uses camelCase. And anyway, the plan is to replace this
+    # with an official google SDK in the near future anyway.
     contents: list[_GeminiContent]
     tools: NotRequired[_GeminiTools]
-    tool_config: NotRequired[_GeminiToolConfig]
-    safety_settings: NotRequired[list[GeminiSafetySettings]]
-    # we don't implement `generationConfig`, instead we use a named tool for the response
-    system_instruction: NotRequired[_GeminiTextContent]
+    toolConfig: NotRequired[_GeminiToolConfig]
+    safetySettings: NotRequired[list[GeminiSafetySettings]]
+    systemInstruction: NotRequired[_GeminiTextContent]
     """
     Developer generated system instructions, see
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
-    generation_config: NotRequired[_GeminiGenerationConfig]
+    generationConfig: NotRequired[_GeminiGenerationConfig]


 class GeminiSafetySettings(TypedDict):
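Putting the renamed keys together, a populated request body now looks roughly like this (values invented for the example; only the top-level keys changed casing in this release, so nested generation-config keys stay snake_case):

# Illustrative request body matching the camelCase `_GeminiRequest` keys above.
request_data = {
    'contents': [{'role': 'user', 'parts': [{'text': 'What is the capital of France?'}]}],
    'systemInstruction': {'role': 'user', 'parts': [{'text': 'Be concise.'}]},
    'safetySettings': [{'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_ONLY_HIGH'}],
    'generationConfig': {'temperature': 0.0},
}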
@@ -757,7 +761,7 @@ _gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 _gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))


-class _GeminiJsonSchema:
+class _GeminiJsonSchema(WalkJsonSchema):
     """Transforms the JSON Schema from Pydantic to be suitable for Gemini.

     Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
@@ -768,72 +772,58 @@ class _GeminiJsonSchema:
     * gemini doesn't allow `$defs` — we need to inline the definitions where possible
     """

-    def __init__(self, schema: _utils.ObjectJsonSchema):
-        self.schema = deepcopy(schema)
-        self.defs = self.schema.pop('$defs', {})
-
-    def simplify(self) -> dict[str, Any]:
-        self._simplify(self.schema, refs_stack=())
-        return self.schema
+    def __init__(self, schema: JsonSchema):
+        super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)

-    def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
+    def transform(self, schema: JsonSchema) -> JsonSchema:
         schema.pop('title', None)
         schema.pop('default', None)
         schema.pop('$schema', None)
+        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
+            # Gemini doesn't support const, but it does support enum with a single value
+            schema['enum'] = [const]
+        schema.pop('discriminator', None)
+        schema.pop('examples', None)
+
+        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
+        # where we add notes about these properties to the field description?
         schema.pop('exclusiveMaximum', None)
         schema.pop('exclusiveMinimum', None)
-        if ref := schema.pop('$ref', None):
-            # noinspection PyTypeChecker
-            key = re.sub(r'^#/\$defs/', '', ref)
-            if key in refs_stack:
-                raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
-            refs_stack += (key,)
-            schema_def = self.defs[key]
-            self._simplify(schema_def, refs_stack)
-            schema.update(schema_def)
-            return
-
-        if any_of := schema.get('anyOf'):
-            for item_schema in any_of:
-                self._simplify(item_schema, refs_stack)
-            if len(any_of) == 2 and {'type': 'null'} in any_of:
-                for item_schema in any_of:
-                    if item_schema != {'type': 'null'}:
-                        schema.clear()
-                        schema.update(item_schema)
-                        schema['nullable'] = True
-                        return

         type_ = schema.get('type')
+        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
+            # This gets hit when we have a discriminated union
+            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
+            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
+            schema['anyOf'] = schema.pop('oneOf')

-        if type_ == 'object':
-            self._object(schema, refs_stack)
-        elif type_ == 'array':
-            return self._array(schema, refs_stack)
-        elif type_ == 'string' and (fmt := schema.pop('format', None)):
+        if type_ == 'string' and (fmt := schema.pop('format', None)):
             description = schema.get('description')
             if description:
                 schema['description'] = f'{description} (format: {fmt})'
             else:
                 schema['description'] = f'Format: {fmt}'

-    def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
-        ad_props = schema.pop('additionalProperties', None)
-        if ad_props:
-            raise UserError('Additional properties in JSON Schema are not supported by Gemini')
-
-        if properties := schema.get('properties'):  # pragma: no branch
-            for value in properties.values():
-                self._simplify(value, refs_stack)
-
-    def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
-        if prefix_items := schema.get('prefixItems'):
-            # TODO I think this not is supported by Gemini, maybe we should raise an error?
-            for prefix_item in prefix_items:
-                self._simplify(prefix_item, refs_stack)
-
-        if items_schema := schema.get('items'):  # pragma: no branch
-            self._simplify(items_schema, refs_stack)
+        if '$ref' in schema:
+            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')
+
+        if 'prefixItems' in schema:
+            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
+            prefix_items = schema.pop('prefixItems')
+            items = schema.get('items')
+            unique_items = [items] if items is not None else []
+            for item in prefix_items:
+                if item not in unique_items:
+                    unique_items.append(item)
+            if len(unique_items) > 1:  # pragma: no cover
+                schema['items'] = {'anyOf': unique_items}
+            elif len(unique_items) == 1:
+                schema['items'] = unique_items[0]
+            schema.setdefault('minItems', len(prefix_items))
+            if items is None:
+                schema.setdefault('maxItems', len(prefix_items))
+
+        return schema


 def _ensure_decodeable(content: bytearray) -> bytearray:
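To make `transform` concrete, here is a hand-worked before/after for the two most involved rules, format folding and `prefixItems` conversion. The full pipeline also inlines `$defs` and simplifies nullable unions via the `WalkJsonSchema` base class, which isn't shown in this diff.

# Hand-traced through the `transform` rules above; not produced by running the library.
point_before = {
    'title': 'Point',
    'type': 'array',
    'prefixItems': [{'type': 'number'}, {'type': 'number'}],
}
point_after = {
    'type': 'array',
    'items': {'type': 'number'},  # the two identical prefix items collapse to one schema
    'minItems': 2,
    'maxItems': 2,  # set because there was no trailing `items` schema
}

timestamp_before = {'title': 'When', 'type': 'string', 'format': 'date-time'}
timestamp_after = {'type': 'string', 'description': 'Format: date-time'}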
pydantic_ai/models/groq.py:

@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from itertools import chain
 from typing import Literal, Union, cast, overload

 from typing_extensions import assert_never
@@ -31,7 +30,7 @@ from ..messages import (
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, get_user_agent

 try:
     from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
@@ -193,12 +192,12 @@ class GroqModel(Model):
         # standalone function to make it easier to override
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'

-        groq_messages = list(chain(*(self._map_message(m) for m in messages)))
+        groq_messages = self._map_messages(messages)

         try:
             return await self.client.chat.completions.create(
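The `tool_choice` logic above in standalone form: Groq is only told a tool call is 'required' when the caller cannot accept plain-text output, e.g. when a structured output tool must be used.

# Standalone restatement of the branch above.
from __future__ import annotations

from typing import Literal


def choose_tool_choice(has_tools: bool, allow_text_output: bool) -> Literal['none', 'required', 'auto'] | None:
    if not has_tools:
        return None  # nothing to call, let the API default apply
    if not allow_text_output:
        return 'required'  # the model must answer via a tool call
    return 'auto'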
@@ -218,6 +217,7 @@ class GroqModel(Model):
                 presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                 frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                 logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
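The Groq SDK accepts per-request `extra_headers` (the same mechanism as the OpenAI SDK it mirrors), which is how the pydantic-ai user agent now rides along. A hedged usage sketch, with an invented header value and model name:

# Usage sketch for per-request headers with the Groq SDK; the header value and
# model name are invented for the example.
from groq import AsyncGroq


async def ask(client: AsyncGroq, prompt: str) -> str:
    response = await client.chat.completions.create(
        model='llama-3.3-70b-versatile',
        messages=[{'role': 'user', 'content': prompt}],
        extra_headers={'User-Agent': 'pydantic-ai/0.1.0'},
    )
    return response.choices[0].message.content or ''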
@@ -251,34 +251,39 @@ class GroqModel(Model):

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools

-    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
-        if isinstance(message, ModelRequest):
-            yield from self._map_user_message(message)
-        elif isinstance(message, ModelResponse):
-            texts: list[str] = []
-            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
-            for item in message.parts:
-                if isinstance(item, TextPart):
-                    texts.append(item.content)
-                elif isinstance(item, ToolCallPart):
-                    tool_calls.append(self._map_tool_call(item))
-                else:
-                    assert_never(item)
-            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
-            if texts:
-                # Note: model responses from this model should only have one text item, so the following
-                # shouldn't merge multiple texts into one unless you switch models between runs:
-                message_param['content'] = '\n\n'.join(texts)
-            if tool_calls:
-                message_param['tool_calls'] = tool_calls
-            yield message_param
-        else:
-            assert_never(message)
+        groq_messages: list[chat.ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                groq_messages.extend(self._map_user_message(message))
+            elif isinstance(message, ModelResponse):
+                texts: list[str] = []
+                tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        texts.append(item.content)
+                    elif isinstance(item, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(item))
+                    else:
+                        assert_never(item)
+                message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+                if texts:
+                    # Note: model responses from this model should only have one text item, so the following
+                    # shouldn't merge multiple texts into one unless you switch models between runs:
+                    message_param['content'] = '\n\n'.join(texts)
+                if tool_calls:
+                    message_param['tool_calls'] = tool_calls
+                groq_messages.append(message_param)
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions))
+        return groq_messages

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
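Flattening the old per-message generator into a single `_map_messages` pass is what makes it possible to prepend one system message built from the run's instructions. For a short history, the output now has roughly this shape (values invented for the example):

# Expected shape of `_map_messages` output for a run that carries instructions.
groq_messages = [
    {'role': 'system', 'content': 'Answer in one sentence.'},  # from _get_instructions
    {'role': 'user', 'content': 'What is the capital of France?'},
    {'role': 'assistant', 'content': 'Paris.'},
]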
pydantic_ai/models/instrumented.py:

@@ -260,7 +260,7 @@ class InstrumentedModel(WrapperModel):

     @staticmethod
     def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
-        result: list[Event] = []
+        events: list[Event] = []
         for message_index, message in enumerate(messages):
             message_events: list[Event] = []
             if isinstance(message, ModelRequest):
@@ -274,10 +274,10 @@ class InstrumentedModel(WrapperModel):
                     'gen_ai.message.index': message_index,
                     **(event.attributes or {}),
                 }
-            result.extend(message_events)
-        for event in result:
+            events.extend(message_events)
+        for event in events:
             event.body = InstrumentedModel.serialize_any(event.body)
-        return result
+        return events

     @staticmethod
     def serialize_any(value: Any) -> str: