pydantic-ai-slim 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +67 -55
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +503 -163
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +25 -5
- pydantic_ai/models/_json_schema.py +156 -0
- pydantic_ai/models/anthropic.py +14 -4
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +65 -75
- pydantic_ai/models/groq.py +34 -29
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +67 -58
- pydantic_ai/models/openai.py +113 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/models/wrapper.py +3 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/azure.py +2 -2
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +3 -3
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.0.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.54.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/openai.py
CHANGED
|
@@ -30,6 +30,7 @@ from ..messages import (
|
|
|
30
30
|
ToolCallPart,
|
|
31
31
|
ToolReturnPart,
|
|
32
32
|
UserPromptPart,
|
|
33
|
+
VideoUrl,
|
|
33
34
|
)
|
|
34
35
|
from ..settings import ModelSettings
|
|
35
36
|
from ..tools import ToolDefinition
|
|
@@ -39,7 +40,9 @@ from . import (
|
|
|
39
40
|
StreamedResponse,
|
|
40
41
|
cached_async_http_client,
|
|
41
42
|
check_allow_model_requests,
|
|
43
|
+
get_user_agent,
|
|
42
44
|
)
|
|
45
|
+
from ._json_schema import JsonSchema, WalkJsonSchema
|
|
43
46
|
|
|
44
47
|
try:
|
|
45
48
|
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
|
|
@@ -251,15 +254,12 @@ class OpenAIModel(Model):
|
|
|
251
254
|
# standalone function to make it easier to override
|
|
252
255
|
if not tools:
|
|
253
256
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
254
|
-
elif not model_request_parameters.
|
|
257
|
+
elif not model_request_parameters.allow_text_output:
|
|
255
258
|
tool_choice = 'required'
|
|
256
259
|
else:
|
|
257
260
|
tool_choice = 'auto'
|
|
258
261
|
|
|
259
|
-
openai_messages
|
|
260
|
-
for m in messages:
|
|
261
|
-
async for msg in self._map_message(m):
|
|
262
|
-
openai_messages.append(msg)
|
|
262
|
+
openai_messages = await self._map_messages(messages)
|
|
263
263
|
|
|
264
264
|
try:
|
|
265
265
|
return await self.client.chat.completions.create(
|
|
@@ -282,6 +282,7 @@ class OpenAIModel(Model):
|
|
|
282
282
|
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
|
|
283
283
|
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
|
|
284
284
|
user=model_settings.get('openai_user', NOT_GIVEN),
|
|
285
|
+
extra_headers={'User-Agent': get_user_agent()},
|
|
285
286
|
)
|
|
286
287
|
except APIStatusError as e:
|
|
287
288
|
if (status_code := e.status_code) >= 400:
|
|
@@ -315,35 +316,40 @@ class OpenAIModel(Model):
|
|
|
315
316
|
|
|
316
317
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
|
|
317
318
|
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
318
|
-
if model_request_parameters.
|
|
319
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.
|
|
319
|
+
if model_request_parameters.output_tools:
|
|
320
|
+
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
320
321
|
return tools
|
|
321
322
|
|
|
322
|
-
async def
|
|
323
|
+
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
|
|
323
324
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
325
|
+
openai_messages: list[chat.ChatCompletionMessageParam] = []
|
|
326
|
+
for message in messages:
|
|
327
|
+
if isinstance(message, ModelRequest):
|
|
328
|
+
async for item in self._map_user_message(message):
|
|
329
|
+
openai_messages.append(item)
|
|
330
|
+
elif isinstance(message, ModelResponse):
|
|
331
|
+
texts: list[str] = []
|
|
332
|
+
tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
|
|
333
|
+
for item in message.parts:
|
|
334
|
+
if isinstance(item, TextPart):
|
|
335
|
+
texts.append(item.content)
|
|
336
|
+
elif isinstance(item, ToolCallPart):
|
|
337
|
+
tool_calls.append(self._map_tool_call(item))
|
|
338
|
+
else:
|
|
339
|
+
assert_never(item)
|
|
340
|
+
message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
|
|
341
|
+
if texts:
|
|
342
|
+
# Note: model responses from this model should only have one text item, so the following
|
|
343
|
+
# shouldn't merge multiple texts into one unless you switch models between runs:
|
|
344
|
+
message_param['content'] = '\n\n'.join(texts)
|
|
345
|
+
if tool_calls:
|
|
346
|
+
message_param['tool_calls'] = tool_calls
|
|
347
|
+
openai_messages.append(message_param)
|
|
348
|
+
else:
|
|
349
|
+
assert_never(message)
|
|
350
|
+
if instructions := self._get_instructions(messages):
|
|
351
|
+
openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
|
|
352
|
+
return openai_messages
|
|
347
353
|
|
|
348
354
|
@staticmethod
|
|
349
355
|
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
@@ -446,6 +452,8 @@ class OpenAIModel(Model):
|
|
|
446
452
|
# file_data = f'data:{media_type};base64,{base64_encoded}'
|
|
447
453
|
# file = File(file={'file_data': file_data, 'file_name': item.url, 'file_id': item.url}, type='file')
|
|
448
454
|
# content.append(file)
|
|
455
|
+
elif isinstance(item, VideoUrl): # pragma: no cover
|
|
456
|
+
raise NotImplementedError('VideoUrl is not supported for OpenAI')
|
|
449
457
|
else:
|
|
450
458
|
assert_never(item)
|
|
451
459
|
return chat.ChatCompletionUserMessageParam(role='user', content=content)
|
|
@@ -589,19 +597,19 @@ class OpenAIResponsesModel(Model):
|
|
|
589
597
|
# standalone function to make it easier to override
|
|
590
598
|
if not tools:
|
|
591
599
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
592
|
-
elif not model_request_parameters.
|
|
600
|
+
elif not model_request_parameters.allow_text_output:
|
|
593
601
|
tool_choice = 'required'
|
|
594
602
|
else:
|
|
595
603
|
tool_choice = 'auto'
|
|
596
604
|
|
|
597
|
-
|
|
605
|
+
instructions, openai_messages = await self._map_messages(messages)
|
|
598
606
|
reasoning = self._get_reasoning(model_settings)
|
|
599
607
|
|
|
600
608
|
try:
|
|
601
609
|
return await self.client.responses.create(
|
|
602
610
|
input=openai_messages,
|
|
603
611
|
model=self._model_name,
|
|
604
|
-
instructions=
|
|
612
|
+
instructions=instructions,
|
|
605
613
|
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
|
|
606
614
|
tools=tools or NOT_GIVEN,
|
|
607
615
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
@@ -613,6 +621,7 @@ class OpenAIResponsesModel(Model):
|
|
|
613
621
|
timeout=model_settings.get('timeout', NOT_GIVEN),
|
|
614
622
|
reasoning=reasoning,
|
|
615
623
|
user=model_settings.get('openai_user', NOT_GIVEN),
|
|
624
|
+
extra_headers={'User-Agent': get_user_agent()},
|
|
616
625
|
)
|
|
617
626
|
except APIStatusError as e:
|
|
618
627
|
if (status_code := e.status_code) >= 400:
|
|
@@ -629,8 +638,8 @@ class OpenAIResponsesModel(Model):
|
|
|
629
638
|
|
|
630
639
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
|
|
631
640
|
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
632
|
-
if model_request_parameters.
|
|
633
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.
|
|
641
|
+
if model_request_parameters.output_tools:
|
|
642
|
+
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
634
643
|
return tools
|
|
635
644
|
|
|
636
645
|
@staticmethod
|
|
@@ -644,15 +653,16 @@ class OpenAIResponsesModel(Model):
|
|
|
644
653
|
'strict': f.strict or False,
|
|
645
654
|
}
|
|
646
655
|
|
|
647
|
-
async def
|
|
656
|
+
async def _map_messages(
|
|
657
|
+
self, messages: list[ModelMessage]
|
|
658
|
+
) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
|
|
648
659
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
|
|
649
|
-
system_prompt: str = ''
|
|
650
660
|
openai_messages: list[responses.ResponseInputItemParam] = []
|
|
651
661
|
for message in messages:
|
|
652
662
|
if isinstance(message, ModelRequest):
|
|
653
663
|
for part in message.parts:
|
|
654
664
|
if isinstance(part, SystemPromptPart):
|
|
655
|
-
|
|
665
|
+
openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content))
|
|
656
666
|
elif isinstance(part, UserPromptPart):
|
|
657
667
|
openai_messages.append(await self._map_user_prompt(part))
|
|
658
668
|
elif isinstance(part, ToolReturnPart):
|
|
@@ -689,7 +699,8 @@ class OpenAIResponsesModel(Model):
|
|
|
689
699
|
assert_never(item)
|
|
690
700
|
else:
|
|
691
701
|
assert_never(message)
|
|
692
|
-
|
|
702
|
+
instructions = self._get_instructions(messages) or NOT_GIVEN
|
|
703
|
+
return instructions, openai_messages
|
|
693
704
|
|
|
694
705
|
@staticmethod
|
|
695
706
|
def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
|
|
@@ -762,6 +773,8 @@ class OpenAIResponsesModel(Model):
|
|
|
762
773
|
filename=f'filename.{item.format}',
|
|
763
774
|
)
|
|
764
775
|
)
|
|
776
|
+
elif isinstance(item, VideoUrl): # pragma: no cover
|
|
777
|
+
raise NotImplementedError('VideoUrl is not supported for OpenAI.')
|
|
765
778
|
else:
|
|
766
779
|
assert_never(item)
|
|
767
780
|
return responses.EasyInputMessageParam(role='user', content=content)
|
|
@@ -919,137 +932,79 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
|
|
|
919
932
|
)
|
|
920
933
|
|
|
921
934
|
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
|
|
927
|
-
but this basically just requires:
|
|
928
|
-
* `additionalProperties` must be set to false for each object in the parameters
|
|
929
|
-
* all fields in properties must be marked as required
|
|
930
|
-
"""
|
|
931
|
-
assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
|
|
935
|
+
@dataclass
|
|
936
|
+
class _OpenAIJsonSchema(WalkJsonSchema):
|
|
937
|
+
"""Recursively handle the schema to make it compatible with OpenAI strict mode.
|
|
932
938
|
|
|
933
|
-
|
|
934
|
-
|
|
939
|
+
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
|
|
940
|
+
but this basically just requires:
|
|
941
|
+
* `additionalProperties` must be set to false for each object in the parameters
|
|
942
|
+
* all fields in properties must be marked as required
|
|
943
|
+
"""
|
|
935
944
|
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
945
|
+
def __init__(self, schema: JsonSchema, strict: bool | None):
|
|
946
|
+
super().__init__(schema)
|
|
947
|
+
self.strict = strict
|
|
948
|
+
self.is_strict_compatible = True
|
|
949
|
+
|
|
950
|
+
def transform(self, schema: JsonSchema) -> JsonSchema:
|
|
951
|
+
# Remove unnecessary keys
|
|
952
|
+
schema.pop('title', None)
|
|
953
|
+
schema.pop('default', None)
|
|
954
|
+
schema.pop('$schema', None)
|
|
955
|
+
schema.pop('discriminator', None)
|
|
956
|
+
|
|
957
|
+
# Remove incompatible keys, but note their impact in the description provided to the LLM
|
|
958
|
+
description = schema.get('description')
|
|
959
|
+
min_length = schema.pop('minLength', None)
|
|
960
|
+
max_length = schema.pop('minLength', None)
|
|
961
|
+
if description is not None:
|
|
962
|
+
notes = list[str]()
|
|
963
|
+
if min_length is not None: # pragma: no cover
|
|
964
|
+
notes.append(f'min_length={min_length}')
|
|
965
|
+
if max_length is not None: # pragma: no cover
|
|
966
|
+
notes.append(f'max_length={max_length}')
|
|
967
|
+
if notes: # pragma: no cover
|
|
968
|
+
schema['description'] = f'{description} ({", ".join(notes)})'
|
|
939
969
|
|
|
940
|
-
# Process schema based on its type
|
|
941
970
|
schema_type = schema.get('type')
|
|
942
971
|
if schema_type == 'object':
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
schema['
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
972
|
+
if self.strict is True:
|
|
973
|
+
# additional properties are disallowed
|
|
974
|
+
schema['additionalProperties'] = False
|
|
975
|
+
|
|
976
|
+
# all properties are required
|
|
977
|
+
if 'properties' not in schema:
|
|
978
|
+
schema['properties'] = dict[str, Any]()
|
|
979
|
+
schema['required'] = list(schema['properties'].keys())
|
|
980
|
+
|
|
981
|
+
elif self.strict is None:
|
|
982
|
+
if (
|
|
983
|
+
schema.get('additionalProperties') is not False
|
|
984
|
+
or 'properties' not in schema
|
|
985
|
+
or 'required' not in schema
|
|
986
|
+
):
|
|
987
|
+
self.is_strict_compatible = False
|
|
988
|
+
else:
|
|
989
|
+
required = schema['required']
|
|
990
|
+
for k in schema['properties'].keys():
|
|
991
|
+
if k not in required:
|
|
992
|
+
self.is_strict_compatible = False
|
|
962
993
|
return schema
|
|
963
994
|
|
|
964
|
-
def _make_object_schema_strict(self, schema: dict[str, Any]) -> None:
|
|
965
|
-
schema['additionalProperties'] = False
|
|
966
|
-
|
|
967
|
-
# Handle patternProperties; note this may not be compatible with strict mode but is included for completeness
|
|
968
|
-
if 'patternProperties' in schema and isinstance(schema['patternProperties'], dict):
|
|
969
|
-
pattern_props: dict[str, Any] = schema['patternProperties']
|
|
970
|
-
schema['patternProperties'] = {str(k): self.make_schema_strict(v) for k, v in pattern_props.items()}
|
|
971
|
-
|
|
972
|
-
# Handle properties — update their schemas recursively, and make all properties required
|
|
973
|
-
if 'properties' in schema and isinstance(schema['properties'], dict):
|
|
974
|
-
properties: dict[str, Any] = schema['properties']
|
|
975
|
-
schema['properties'] = {k: self.make_schema_strict(v) for k, v in properties.items()}
|
|
976
|
-
schema['required'] = list(properties.keys())
|
|
977
|
-
|
|
978
|
-
def is_schema_strict(self, schema: dict[str, Any]) -> bool:
|
|
979
|
-
"""Check if the schema is strict-mode-compatible.
|
|
980
|
-
|
|
981
|
-
A schema is compatible if:
|
|
982
|
-
* `additionalProperties` is set to false for each object in the parameters
|
|
983
|
-
* all fields in properties are marked as required
|
|
984
|
-
|
|
985
|
-
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details.
|
|
986
|
-
"""
|
|
987
|
-
assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
|
|
988
|
-
|
|
989
|
-
# Note that checking the defs first is usually the fastest way to proceed, but
|
|
990
|
-
# it makes it hard/impossible to hit coverage below, hence all the pragma no covers.
|
|
991
|
-
# I still included the handling below because I'm not _confident_ those code paths can't be hit.
|
|
992
|
-
if defs := schema.get('$defs'):
|
|
993
|
-
if not all(self.is_schema_strict(v) for v in defs.values()): # pragma: no branch
|
|
994
|
-
return False
|
|
995
|
-
|
|
996
|
-
schema_type = schema.get('type')
|
|
997
|
-
if schema_type == 'object':
|
|
998
|
-
if not self._is_object_schema_strict(schema):
|
|
999
|
-
return False
|
|
1000
|
-
elif schema_type == 'array':
|
|
1001
|
-
if 'items' in schema:
|
|
1002
|
-
items: Any = schema['items']
|
|
1003
|
-
if not self.is_schema_strict(items): # pragma: no cover
|
|
1004
|
-
return False
|
|
1005
|
-
if 'prefixItems' in schema:
|
|
1006
|
-
prefix_items: list[Any] = schema['prefixItems']
|
|
1007
|
-
if not all(self.is_schema_strict(item) for item in prefix_items): # pragma: no cover
|
|
1008
|
-
return False
|
|
1009
|
-
elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
|
|
1010
|
-
pass
|
|
1011
|
-
elif 'oneOf' in schema: # pragma: no cover
|
|
1012
|
-
if not all(self.is_schema_strict(item) for item in schema['oneOf']):
|
|
1013
|
-
return False
|
|
1014
|
-
|
|
1015
|
-
elif 'anyOf' in schema: # pragma: no cover
|
|
1016
|
-
if not all(self.is_schema_strict(item) for item in schema['anyOf']):
|
|
1017
|
-
return False
|
|
1018
|
-
|
|
1019
|
-
return True
|
|
1020
|
-
|
|
1021
|
-
def _is_object_schema_strict(self, schema: dict[str, Any]) -> bool:
|
|
1022
|
-
"""Check if the schema is an object and has additionalProperties set to false."""
|
|
1023
|
-
if schema.get('additionalProperties') is not False:
|
|
1024
|
-
return False
|
|
1025
|
-
if 'properties' not in schema: # pragma: no cover
|
|
1026
|
-
return False
|
|
1027
|
-
if 'required' not in schema: # pragma: no cover
|
|
1028
|
-
return False
|
|
1029
|
-
|
|
1030
|
-
for k, v in schema['properties'].items():
|
|
1031
|
-
if k not in schema['required']:
|
|
1032
|
-
return False
|
|
1033
|
-
if not self.is_schema_strict(v): # pragma: no cover
|
|
1034
|
-
return False
|
|
1035
|
-
|
|
1036
|
-
return True
|
|
1037
|
-
|
|
1038
995
|
|
|
1039
996
|
def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
1040
997
|
"""Customize the request parameters for OpenAI models."""
|
|
1041
998
|
|
|
1042
999
|
def _customize_tool_def(t: ToolDefinition):
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
return replace(t, strict=strict)
|
|
1049
|
-
return t
|
|
1000
|
+
schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
|
|
1001
|
+
parameters_json_schema = schema_transformer.walk()
|
|
1002
|
+
if t.strict is None:
|
|
1003
|
+
t = replace(t, strict=schema_transformer.is_strict_compatible)
|
|
1004
|
+
return replace(t, parameters_json_schema=parameters_json_schema)
|
|
1050
1005
|
|
|
1051
1006
|
return ModelRequestParameters(
|
|
1052
1007
|
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
|
|
1053
|
-
|
|
1054
|
-
|
|
1008
|
+
allow_text_output=model_request_parameters.allow_text_output,
|
|
1009
|
+
output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
|
|
1055
1010
|
)
|
pydantic_ai/models/test.py
CHANGED
|
@@ -22,9 +22,9 @@ from ..messages import (
|
|
|
22
22
|
ToolCallPart,
|
|
23
23
|
ToolReturnPart,
|
|
24
24
|
)
|
|
25
|
-
from ..result import Usage
|
|
26
25
|
from ..settings import ModelSettings
|
|
27
26
|
from ..tools import ToolDefinition
|
|
27
|
+
from ..usage import Usage
|
|
28
28
|
from . import (
|
|
29
29
|
Model,
|
|
30
30
|
ModelRequestParameters,
|
|
@@ -34,15 +34,15 @@ from .function import _estimate_string_tokens, _estimate_usage # pyright: ignor
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
@dataclass
|
|
37
|
-
class
|
|
38
|
-
"""A private wrapper class to tag
|
|
37
|
+
class _WrappedTextOutput:
|
|
38
|
+
"""A private wrapper class to tag an output that came from the custom_output_text field."""
|
|
39
39
|
|
|
40
40
|
value: str | None
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
@dataclass
|
|
44
|
-
class
|
|
45
|
-
"""A wrapper class to tag
|
|
44
|
+
class _WrappedToolOutput:
|
|
45
|
+
"""A wrapper class to tag an output that came from the custom_output_args field."""
|
|
46
46
|
|
|
47
47
|
value: Any | None
|
|
48
48
|
|
|
@@ -65,16 +65,16 @@ class TestModel(Model):
|
|
|
65
65
|
|
|
66
66
|
call_tools: list[str] | Literal['all'] = 'all'
|
|
67
67
|
"""List of tools to call. If `'all'`, all tools will be called."""
|
|
68
|
-
|
|
69
|
-
"""If set, this text is returned as the final
|
|
70
|
-
|
|
71
|
-
"""If set, these args will be passed to the
|
|
68
|
+
custom_output_text: str | None = None
|
|
69
|
+
"""If set, this text is returned as the final output."""
|
|
70
|
+
custom_output_args: Any | None = None
|
|
71
|
+
"""If set, these args will be passed to the output tool."""
|
|
72
72
|
seed: int = 0
|
|
73
73
|
"""Seed for generating random data."""
|
|
74
74
|
last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
|
|
75
75
|
"""The last ModelRequestParameters passed to the model in a request.
|
|
76
76
|
|
|
77
|
-
The ModelRequestParameters contains information about the function and
|
|
77
|
+
The ModelRequestParameters contains information about the function and output tools available during request handling.
|
|
78
78
|
|
|
79
79
|
This is set when a request is made, so will reflect the function tools from the last step of the last run.
|
|
80
80
|
"""
|
|
@@ -88,7 +88,6 @@ class TestModel(Model):
|
|
|
88
88
|
model_request_parameters: ModelRequestParameters,
|
|
89
89
|
) -> tuple[ModelResponse, Usage]:
|
|
90
90
|
self.last_model_request_parameters = model_request_parameters
|
|
91
|
-
|
|
92
91
|
model_response = self._request(messages, model_settings, model_request_parameters)
|
|
93
92
|
usage = _estimate_usage([*messages, model_response])
|
|
94
93
|
return model_response, usage
|
|
@@ -128,29 +127,29 @@ class TestModel(Model):
|
|
|
128
127
|
tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
|
|
129
128
|
return [(r.name, r) for r in tools_to_call]
|
|
130
129
|
|
|
131
|
-
def
|
|
132
|
-
if self.
|
|
133
|
-
assert model_request_parameters.
|
|
134
|
-
'Plain response not allowed, but `
|
|
130
|
+
def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput:
|
|
131
|
+
if self.custom_output_text is not None:
|
|
132
|
+
assert model_request_parameters.allow_text_output, (
|
|
133
|
+
'Plain response not allowed, but `custom_output_text` is set.'
|
|
135
134
|
)
|
|
136
|
-
assert self.
|
|
137
|
-
return
|
|
138
|
-
elif self.
|
|
139
|
-
assert model_request_parameters.
|
|
140
|
-
'No
|
|
135
|
+
assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.'
|
|
136
|
+
return _WrappedTextOutput(self.custom_output_text)
|
|
137
|
+
elif self.custom_output_args is not None:
|
|
138
|
+
assert model_request_parameters.output_tools is not None, (
|
|
139
|
+
'No output tools provided, but `custom_output_args` is set.'
|
|
141
140
|
)
|
|
142
|
-
|
|
141
|
+
output_tool = model_request_parameters.output_tools[0]
|
|
143
142
|
|
|
144
|
-
if k :=
|
|
145
|
-
return
|
|
143
|
+
if k := output_tool.outer_typed_dict_key:
|
|
144
|
+
return _WrappedToolOutput({k: self.custom_output_args})
|
|
146
145
|
else:
|
|
147
|
-
return
|
|
148
|
-
elif model_request_parameters.
|
|
149
|
-
return
|
|
150
|
-
elif model_request_parameters.
|
|
151
|
-
return
|
|
146
|
+
return _WrappedToolOutput(self.custom_output_args)
|
|
147
|
+
elif model_request_parameters.allow_text_output:
|
|
148
|
+
return _WrappedTextOutput(None)
|
|
149
|
+
elif model_request_parameters.output_tools:
|
|
150
|
+
return _WrappedToolOutput(None)
|
|
152
151
|
else:
|
|
153
|
-
return
|
|
152
|
+
return _WrappedTextOutput(None) # pragma: no cover
|
|
154
153
|
|
|
155
154
|
def _request(
|
|
156
155
|
self,
|
|
@@ -159,8 +158,8 @@ class TestModel(Model):
|
|
|
159
158
|
model_request_parameters: ModelRequestParameters,
|
|
160
159
|
) -> ModelResponse:
|
|
161
160
|
tool_calls = self._get_tool_calls(model_request_parameters)
|
|
162
|
-
|
|
163
|
-
|
|
161
|
+
output_wrapper = self._get_output(model_request_parameters)
|
|
162
|
+
output_tools = model_request_parameters.output_tools
|
|
164
163
|
|
|
165
164
|
# if there are tools, the first thing we want to do is call all of them
|
|
166
165
|
if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
|
|
@@ -176,29 +175,29 @@ class TestModel(Model):
|
|
|
176
175
|
# check if there are any retry prompts, if so retry them
|
|
177
176
|
new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
|
|
178
177
|
if new_retry_names:
|
|
179
|
-
# Handle retries for both function tools and
|
|
178
|
+
# Handle retries for both function tools and output tools
|
|
180
179
|
# Check function tools first
|
|
181
180
|
retry_parts: list[ModelResponsePart] = [
|
|
182
181
|
ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
|
|
183
182
|
]
|
|
184
|
-
# Check
|
|
185
|
-
if
|
|
183
|
+
# Check output tools
|
|
184
|
+
if output_tools:
|
|
186
185
|
retry_parts.extend(
|
|
187
186
|
[
|
|
188
187
|
ToolCallPart(
|
|
189
188
|
tool.name,
|
|
190
|
-
|
|
191
|
-
if isinstance(
|
|
189
|
+
output_wrapper.value
|
|
190
|
+
if isinstance(output_wrapper, _WrappedToolOutput) and output_wrapper.value is not None
|
|
192
191
|
else self.gen_tool_args(tool),
|
|
193
192
|
)
|
|
194
|
-
for tool in
|
|
193
|
+
for tool in output_tools
|
|
195
194
|
if tool.name in new_retry_names
|
|
196
195
|
]
|
|
197
196
|
)
|
|
198
197
|
return ModelResponse(parts=retry_parts, model_name=self._model_name)
|
|
199
198
|
|
|
200
|
-
if isinstance(
|
|
201
|
-
if (response_text :=
|
|
199
|
+
if isinstance(output_wrapper, _WrappedTextOutput):
|
|
200
|
+
if (response_text := output_wrapper.value) is None:
|
|
202
201
|
# build up details of tool responses
|
|
203
202
|
output: dict[str, Any] = {}
|
|
204
203
|
for message in messages:
|
|
@@ -215,16 +214,16 @@ class TestModel(Model):
|
|
|
215
214
|
else:
|
|
216
215
|
return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
|
|
217
216
|
else:
|
|
218
|
-
assert
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
if
|
|
217
|
+
assert output_tools, 'No output tools provided'
|
|
218
|
+
custom_output_args = output_wrapper.value
|
|
219
|
+
output_tool = output_tools[self.seed % len(output_tools)]
|
|
220
|
+
if custom_output_args is not None:
|
|
222
221
|
return ModelResponse(
|
|
223
|
-
parts=[ToolCallPart(
|
|
222
|
+
parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name
|
|
224
223
|
)
|
|
225
224
|
else:
|
|
226
|
-
response_args = self.gen_tool_args(
|
|
227
|
-
return ModelResponse(parts=[ToolCallPart(
|
|
225
|
+
response_args = self.gen_tool_args(output_tool)
|
|
226
|
+
return ModelResponse(parts=[ToolCallPart(output_tool.name, response_args)], model_name=self._model_name)
|
|
228
227
|
|
|
229
228
|
|
|
230
229
|
@dataclass
|
pydantic_ai/models/wrapper.py
CHANGED
|
@@ -37,6 +37,9 @@ class WrapperModel(Model):
|
|
|
37
37
|
async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
|
|
38
38
|
yield response_stream
|
|
39
39
|
|
|
40
|
+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
41
|
+
return self.wrapped.customize_request_parameters(model_request_parameters)
|
|
42
|
+
|
|
40
43
|
@property
|
|
41
44
|
def model_name(self) -> str:
|
|
42
45
|
return self.wrapped.model_name
|
|
@@ -52,6 +52,10 @@ def infer_provider(provider: str) -> Provider[Any]:
|
|
|
52
52
|
from .deepseek import DeepSeekProvider
|
|
53
53
|
|
|
54
54
|
return DeepSeekProvider()
|
|
55
|
+
elif provider == 'azure':
|
|
56
|
+
from .azure import AzureProvider
|
|
57
|
+
|
|
58
|
+
return AzureProvider()
|
|
55
59
|
elif provider == 'google-vertex':
|
|
56
60
|
from .google_vertex import GoogleVertexProvider
|
|
57
61
|
|
pydantic_ai/providers/azure.py
CHANGED
|
@@ -87,9 +87,9 @@ class AzureProvider(Provider[AsyncOpenAI]):
|
|
|
87
87
|
'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable'
|
|
88
88
|
)
|
|
89
89
|
|
|
90
|
-
if not api_key and '
|
|
90
|
+
if not api_key and 'AZURE_OPENAI_API_KEY' not in os.environ: # pragma: no cover
|
|
91
91
|
raise UserError(
|
|
92
|
-
'Must provide one of the `api_key` argument or the `
|
|
92
|
+
'Must provide one of the `api_key` argument or the `AZURE_OPENAI_API_KEY` environment variable'
|
|
93
93
|
)
|
|
94
94
|
|
|
95
95
|
if not api_version and 'OPENAI_API_VERSION' not in os.environ: # pragma: no cover
|