pydantic-ai-slim 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
@@ -30,6 +30,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
+    VideoUrl,
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -39,7 +40,9 @@ from . import (
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
+    get_user_agent,
 )
+from ._json_schema import JsonSchema, WalkJsonSchema

 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
@@ -251,15 +254,12 @@ class OpenAIModel(Model):
         # standalone function to make it easier to override
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'

-        openai_messages: list[chat.ChatCompletionMessageParam] = []
-        for m in messages:
-            async for msg in self._map_message(m):
-                openai_messages.append(msg)
+        openai_messages = await self._map_messages(messages)

         try:
             return await self.client.chat.completions.create(
@@ -282,6 +282,7 @@ class OpenAIModel(Model):
                 logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
                 user=model_settings.get('openai_user', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
@@ -315,35 +316,40 @@ class OpenAIModel(Model):

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools

-    async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
+    async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
-        if isinstance(message, ModelRequest):
-            async for item in self._map_user_message(message):
-                yield item
-        elif isinstance(message, ModelResponse):
-            texts: list[str] = []
-            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
-            for item in message.parts:
-                if isinstance(item, TextPart):
-                    texts.append(item.content)
-                elif isinstance(item, ToolCallPart):
-                    tool_calls.append(self._map_tool_call(item))
-                else:
-                    assert_never(item)
-            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
-            if texts:
-                # Note: model responses from this model should only have one text item, so the following
-                # shouldn't merge multiple texts into one unless you switch models between runs:
-                message_param['content'] = '\n\n'.join(texts)
-            if tool_calls:
-                message_param['tool_calls'] = tool_calls
-            yield message_param
-        else:
-            assert_never(message)
+        openai_messages: list[chat.ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                async for item in self._map_user_message(message):
+                    openai_messages.append(item)
+            elif isinstance(message, ModelResponse):
+                texts: list[str] = []
+                tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        texts.append(item.content)
+                    elif isinstance(item, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(item))
+                    else:
+                        assert_never(item)
+                message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+                if texts:
+                    # Note: model responses from this model should only have one text item, so the following
+                    # shouldn't merge multiple texts into one unless you switch models between runs:
+                    message_param['content'] = '\n\n'.join(texts)
+                if tool_calls:
+                    message_param['tool_calls'] = tool_calls
+                openai_messages.append(message_param)
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
+        return openai_messages

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
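
Note on the `_map_messages` refactor above: mapping now happens in one batched call, and when the run defines instructions they are prepended as a system message. A minimal sketch of the resulting list shape (the instruction and prompt strings are invented for illustration):

    # Illustrative only: the shape the new _map_messages() produces.
    # ChatCompletionMessageParam values are TypedDicts, so plain dicts show the shape.
    openai_messages = [
        {'role': 'system', 'content': 'Answer concisely.'},  # prepended from _get_instructions()
        {'role': 'user', 'content': 'What is new in 0.1.0?'},  # mapped from a UserPromptPart
    ]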
@@ -446,6 +452,8 @@ class OpenAIModel(Model):
                 # file_data = f'data:{media_type};base64,{base64_encoded}'
                 # file = File(file={'file_data': file_data, 'file_name': item.url, 'file_id': item.url}, type='file')
                 # content.append(file)
+            elif isinstance(item, VideoUrl):  # pragma: no cover
+                raise NotImplementedError('VideoUrl is not supported for OpenAI')
             else:
                 assert_never(item)
         return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -589,19 +597,19 @@ class OpenAIResponsesModel(Model):
         # standalone function to make it easier to override
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'

-        system_prompt, openai_messages = await self._map_message(messages)
+        instructions, openai_messages = await self._map_messages(messages)
         reasoning = self._get_reasoning(model_settings)

         try:
             return await self.client.responses.create(
                 input=openai_messages,
                 model=self._model_name,
-                instructions=system_prompt,
+                instructions=instructions,
                 parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                 tools=tools or NOT_GIVEN,
                 tool_choice=tool_choice or NOT_GIVEN,
@@ -613,6 +621,7 @@ class OpenAIResponsesModel(Model):
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 reasoning=reasoning,
                 user=model_settings.get('openai_user', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
@@ -629,8 +638,8 @@ class OpenAIResponsesModel(Model):

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools

     @staticmethod
@@ -644,15 +653,16 @@ class OpenAIResponsesModel(Model):
             'strict': f.strict or False,
         }

-    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
+    async def _map_messages(
+        self, messages: list[ModelMessage]
+    ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
-        system_prompt: str = ''
         openai_messages: list[responses.ResponseInputItemParam] = []
         for message in messages:
             if isinstance(message, ModelRequest):
                 for part in message.parts:
                     if isinstance(part, SystemPromptPart):
-                        system_prompt += part.content
+                        openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content))
                     elif isinstance(part, UserPromptPart):
                         openai_messages.append(await self._map_user_prompt(part))
                     elif isinstance(part, ToolReturnPart):
@@ -689,7 +699,8 @@ class OpenAIResponsesModel(Model):
                         assert_never(item)
             else:
                 assert_never(message)
-        return system_prompt, openai_messages
+        instructions = self._get_instructions(messages) or NOT_GIVEN
+        return instructions, openai_messages

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
@@ -762,6 +773,8 @@ class OpenAIResponsesModel(Model):
                         filename=f'filename.{item.format}',
                     )
                 )
+            elif isinstance(item, VideoUrl):  # pragma: no cover
+                raise NotImplementedError('VideoUrl is not supported for OpenAI.')
             else:
                 assert_never(item)
         return responses.EasyInputMessageParam(role='user', content=content)
@@ -919,137 +932,79 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
     )


-class _StrictSchemaHelper:
-    def make_schema_strict(self, schema: dict[str, Any]) -> dict[str, Any]:
-        """Recursively handle the schema to make it compatible with OpenAI strict mode.
-
-        See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
-        but this basically just requires:
-        * `additionalProperties` must be set to false for each object in the parameters
-        * all fields in properties must be marked as required
-        """
-        assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
+@dataclass
+class _OpenAIJsonSchema(WalkJsonSchema):
+    """Recursively handle the schema to make it compatible with OpenAI strict mode.

-        # Create a copy to avoid modifying the original schema
-        schema = schema.copy()
+    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
+    but this basically just requires:
+    * `additionalProperties` must be set to false for each object in the parameters
+    * all fields in properties must be marked as required
+    """

-        # Handle $defs
-        if defs := schema.get('$defs'):
-            schema['$defs'] = {k: self.make_schema_strict(v) for k, v in defs.items()}
+    def __init__(self, schema: JsonSchema, strict: bool | None):
+        super().__init__(schema)
+        self.strict = strict
+        self.is_strict_compatible = True
+
+    def transform(self, schema: JsonSchema) -> JsonSchema:
+        # Remove unnecessary keys
+        schema.pop('title', None)
+        schema.pop('default', None)
+        schema.pop('$schema', None)
+        schema.pop('discriminator', None)
+
+        # Remove incompatible keys, but note their impact in the description provided to the LLM
+        description = schema.get('description')
+        min_length = schema.pop('minLength', None)
+        max_length = schema.pop('maxLength', None)
+        if description is not None:
+            notes = list[str]()
+            if min_length is not None:  # pragma: no cover
+                notes.append(f'min_length={min_length}')
+            if max_length is not None:  # pragma: no cover
+                notes.append(f'max_length={max_length}')
+            if notes:  # pragma: no cover
+                schema['description'] = f'{description} ({", ".join(notes)})'

-        # Process schema based on its type
         schema_type = schema.get('type')
         if schema_type == 'object':
-            # Handle object type by setting additionalProperties to false
-            # and adding all properties to required list
-            self._make_object_schema_strict(schema)
-        elif schema_type == 'array':
-            # Handle array types by processing their items
-            if 'items' in schema:
-                items: Any = schema['items']
-                schema['items'] = self.make_schema_strict(items)
-            if 'prefixItems' in schema:
-                prefix_items: list[Any] = schema['prefixItems']
-                schema['prefixItems'] = [self.make_schema_strict(item) for item in prefix_items]
-
-        elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
-            pass  # Primitive types need no special handling
-        elif 'oneOf' in schema:
-            schema['oneOf'] = [self.make_schema_strict(item) for item in schema['oneOf']]
-        elif 'anyOf' in schema:
-            schema['anyOf'] = [self.make_schema_strict(item) for item in schema['anyOf']]
-
+            if self.strict is True:
+                # additional properties are disallowed
+                schema['additionalProperties'] = False
+
+                # all properties are required
+                if 'properties' not in schema:
+                    schema['properties'] = dict[str, Any]()
+                schema['required'] = list(schema['properties'].keys())
+
+            elif self.strict is None:
+                if (
+                    schema.get('additionalProperties') is not False
+                    or 'properties' not in schema
+                    or 'required' not in schema
+                ):
+                    self.is_strict_compatible = False
+                else:
+                    required = schema['required']
+                    for k in schema['properties'].keys():
+                        if k not in required:
+                            self.is_strict_compatible = False
         return schema

-    def _make_object_schema_strict(self, schema: dict[str, Any]) -> None:
-        schema['additionalProperties'] = False
-
-        # Handle patternProperties; note this may not be compatible with strict mode but is included for completeness
-        if 'patternProperties' in schema and isinstance(schema['patternProperties'], dict):
-            pattern_props: dict[str, Any] = schema['patternProperties']
-            schema['patternProperties'] = {str(k): self.make_schema_strict(v) for k, v in pattern_props.items()}
-
-        # Handle properties — update their schemas recursively, and make all properties required
-        if 'properties' in schema and isinstance(schema['properties'], dict):
-            properties: dict[str, Any] = schema['properties']
-            schema['properties'] = {k: self.make_schema_strict(v) for k, v in properties.items()}
-            schema['required'] = list(properties.keys())
-
-    def is_schema_strict(self, schema: dict[str, Any]) -> bool:
-        """Check if the schema is strict-mode-compatible.
-
-        A schema is compatible if:
-        * `additionalProperties` is set to false for each object in the parameters
-        * all fields in properties are marked as required
-
-        See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details.
-        """
-        assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
-
-        # Note that checking the defs first is usually the fastest way to proceed, but
-        # it makes it hard/impossible to hit coverage below, hence all the pragma no covers.
-        # I still included the handling below because I'm not _confident_ those code paths can't be hit.
-        if defs := schema.get('$defs'):
-            if not all(self.is_schema_strict(v) for v in defs.values()):  # pragma: no branch
-                return False
-
-        schema_type = schema.get('type')
-        if schema_type == 'object':
-            if not self._is_object_schema_strict(schema):
-                return False
-        elif schema_type == 'array':
-            if 'items' in schema:
-                items: Any = schema['items']
-                if not self.is_schema_strict(items):  # pragma: no cover
-                    return False
-            if 'prefixItems' in schema:
-                prefix_items: list[Any] = schema['prefixItems']
-                if not all(self.is_schema_strict(item) for item in prefix_items):  # pragma: no cover
-                    return False
-        elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
-            pass
-        elif 'oneOf' in schema:  # pragma: no cover
-            if not all(self.is_schema_strict(item) for item in schema['oneOf']):
-                return False
-
-        elif 'anyOf' in schema:  # pragma: no cover
-            if not all(self.is_schema_strict(item) for item in schema['anyOf']):
-                return False
-
-        return True
-
-    def _is_object_schema_strict(self, schema: dict[str, Any]) -> bool:
-        """Check if the schema is an object and has additionalProperties set to false."""
-        if schema.get('additionalProperties') is not False:
-            return False
-        if 'properties' not in schema:  # pragma: no cover
-            return False
-        if 'required' not in schema:  # pragma: no cover
-            return False
-
-        for k, v in schema['properties'].items():
-            if k not in schema['required']:
-                return False
-            if not self.is_schema_strict(v):  # pragma: no cover
-                return False
-
-        return True
-

 def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
     """Customize the request parameters for OpenAI models."""

     def _customize_tool_def(t: ToolDefinition):
-        if t.strict is True:
-            parameters_json_schema = _StrictSchemaHelper().make_schema_strict(t.parameters_json_schema)
-            return replace(t, parameters_json_schema=parameters_json_schema)
-        elif t.strict is None:
-            strict = _StrictSchemaHelper().is_schema_strict(t.parameters_json_schema)
-            return replace(t, strict=strict)
-        return t
+        schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
+        parameters_json_schema = schema_transformer.walk()
+        if t.strict is None:
+            t = replace(t, strict=schema_transformer.is_strict_compatible)
+        return replace(t, parameters_json_schema=parameters_json_schema)

     return ModelRequestParameters(
         function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-        allow_text_result=model_request_parameters.allow_text_result,
-        result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
+        allow_text_output=model_request_parameters.allow_text_output,
+        output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
     )
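
A worked example of the `_OpenAIJsonSchema` transform above, sketched under the assumption that `WalkJsonSchema.walk()` applies `transform` to each nested object schema (the input schema is invented):

    # Illustrative only: a tool parameters schema before the walk...
    schema = {
        'title': 'get_weather',
        'type': 'object',
        'properties': {'city': {'type': 'string', 'description': 'City name', 'minLength': 1}},
    }
    # ...and roughly what _OpenAIJsonSchema(schema, strict=True).walk() yields:
    # {
    #     'type': 'object',
    #     'additionalProperties': False,  # strict mode disallows extra keys
    #     'required': ['city'],           # strict mode requires every property
    #     'properties': {'city': {'type': 'string', 'description': 'City name (min_length=1)'}},
    # }
    # With strict=None, unsupported keys are still stripped, but instead of rewriting the
    # schema the walker only records whether it was already compliant in is_strict_compatible,
    # which _customize_tool_def then copies onto the ToolDefinition.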
@@ -22,9 +22,9 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
+from ..usage import Usage
 from . import (
     Model,
     ModelRequestParameters,
@@ -34,15 +34,15 @@ from .function import _estimate_string_tokens, _estimate_usage # pyright: ignor


 @dataclass
-class _TextResult:
-    """A private wrapper class to tag a result that came from the custom_result_text field."""
+class _WrappedTextOutput:
+    """A private wrapper class to tag an output that came from the custom_output_text field."""

     value: str | None


 @dataclass
-class _FunctionToolResult:
-    """A wrapper class to tag a result that came from the custom_result_args field."""
+class _WrappedToolOutput:
+    """A wrapper class to tag an output that came from the custom_output_args field."""

     value: Any | None

@@ -65,16 +65,16 @@ class TestModel(Model):

     call_tools: list[str] | Literal['all'] = 'all'
     """List of tools to call. If `'all'`, all tools will be called."""
-    custom_result_text: str | None = None
-    """If set, this text is returned as the final result."""
-    custom_result_args: Any | None = None
-    """If set, these args will be passed to the result tool."""
+    custom_output_text: str | None = None
+    """If set, this text is returned as the final output."""
+    custom_output_args: Any | None = None
+    """If set, these args will be passed to the output tool."""
     seed: int = 0
     """Seed for generating random data."""
     last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
     """The last ModelRequestParameters passed to the model in a request.

-    The ModelRequestParameters contains information about the function and result tools available during request handling.
+    The ModelRequestParameters contains information about the function and output tools available during request handling.

     This is set when a request is made, so will reflect the function tools from the last step of the last run.
     """
@@ -88,7 +88,6 @@ class TestModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         self.last_model_request_parameters = model_request_parameters
-
         model_response = self._request(messages, model_settings, model_request_parameters)
         usage = _estimate_usage([*messages, model_response])
         return model_response, usage
@@ -128,29 +127,29 @@ class TestModel(Model):
             tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
         return [(r.name, r) for r in tools_to_call]

-    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
-        if self.custom_result_text is not None:
-            assert model_request_parameters.allow_text_result, (
-                'Plain response not allowed, but `custom_result_text` is set.'
+    def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput:
+        if self.custom_output_text is not None:
+            assert model_request_parameters.allow_text_output, (
+                'Plain response not allowed, but `custom_output_text` is set.'
             )
-            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-            return _TextResult(self.custom_result_text)
-        elif self.custom_result_args is not None:
-            assert model_request_parameters.result_tools is not None, (
-                'No result tools provided, but `custom_result_args` is set.'
+            assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.'
+            return _WrappedTextOutput(self.custom_output_text)
+        elif self.custom_output_args is not None:
+            assert model_request_parameters.output_tools is not None, (
+                'No output tools provided, but `custom_output_args` is set.'
             )
-            result_tool = model_request_parameters.result_tools[0]
+            output_tool = model_request_parameters.output_tools[0]

-            if k := result_tool.outer_typed_dict_key:
-                return _FunctionToolResult({k: self.custom_result_args})
+            if k := output_tool.outer_typed_dict_key:
+                return _WrappedToolOutput({k: self.custom_output_args})
             else:
-                return _FunctionToolResult(self.custom_result_args)
-        elif model_request_parameters.allow_text_result:
-            return _TextResult(None)
-        elif model_request_parameters.result_tools:
-            return _FunctionToolResult(None)
+                return _WrappedToolOutput(self.custom_output_args)
+        elif model_request_parameters.allow_text_output:
+            return _WrappedTextOutput(None)
+        elif model_request_parameters.output_tools:
+            return _WrappedToolOutput(None)
         else:
-            return _TextResult(None)
+            return _WrappedTextOutput(None)  # pragma: no cover

     def _request(
         self,
@@ -159,8 +158,8 @@ class TestModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         tool_calls = self._get_tool_calls(model_request_parameters)
-        result = self._get_result(model_request_parameters)
-        result_tools = model_request_parameters.result_tools
+        output_wrapper = self._get_output(model_request_parameters)
+        output_tools = model_request_parameters.output_tools

         # if there are tools, the first thing we want to do is call all of them
         if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
@@ -176,29 +175,29 @@ class TestModel(Model):
             # check if there are any retry prompts, if so retry them
             new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
             if new_retry_names:
-                # Handle retries for both function tools and result tools
+                # Handle retries for both function tools and output tools
                 # Check function tools first
                 retry_parts: list[ModelResponsePart] = [
                     ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
                 ]
-                # Check result tools
-                if result_tools:
+                # Check output tools
+                if output_tools:
                     retry_parts.extend(
                         [
                             ToolCallPart(
                                 tool.name,
-                                result.value
-                                if isinstance(result, _FunctionToolResult) and result.value is not None
+                                output_wrapper.value
+                                if isinstance(output_wrapper, _WrappedToolOutput) and output_wrapper.value is not None
                                 else self.gen_tool_args(tool),
                             )
-                            for tool in result_tools
+                            for tool in output_tools
                             if tool.name in new_retry_names
                         ]
                     )
                 return ModelResponse(parts=retry_parts, model_name=self._model_name)

-        if isinstance(result, _TextResult):
-            if (response_text := result.value) is None:
+        if isinstance(output_wrapper, _WrappedTextOutput):
+            if (response_text := output_wrapper.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -215,16 +214,16 @@ class TestModel(Model):
             else:
                 return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
         else:
-            assert result_tools, 'No result tools provided'
-            custom_result_args = result.value
-            result_tool = result_tools[self.seed % len(result_tools)]
-            if custom_result_args is not None:
+            assert output_tools, 'No output tools provided'
+            custom_output_args = output_wrapper.value
+            output_tool = output_tools[self.seed % len(output_tools)]
+            if custom_output_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
+                    parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name
                 )
             else:
-                response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)
+                response_args = self.gen_tool_args(output_tool)
+                return ModelResponse(parts=[ToolCallPart(output_tool.name, response_args)], model_name=self._model_name)


 @dataclass
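
For orientation, a hedged sketch of the renamed `TestModel` fields in use, based on the docstrings above (the agent setup and assertion are illustrative, not part of the diff):

    # Illustrative only: custom_output_text replaces custom_result_text in 0.1.0.
    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(custom_output_text='hello'))
    result = agent.run_sync('anything')
    assert result.output == 'hello'  # 0.1.0 likewise renames `result` accessors to `output`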
@@ -37,6 +37,9 @@ class WrapperModel(Model):
         async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
             yield response_stream

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        return self.wrapped.customize_request_parameters(model_request_parameters)
+
     @property
     def model_name(self) -> str:
         return self.wrapped.model_name
@@ -52,6 +52,10 @@ def infer_provider(provider: str) -> Provider[Any]:
         from .deepseek import DeepSeekProvider

         return DeepSeekProvider()
+    elif provider == 'azure':
+        from .azure import AzureProvider
+
+        return AzureProvider()
     elif provider == 'google-vertex':
         from .google_vertex import GoogleVertexProvider

@@ -87,9 +87,9 @@ class AzureProvider(Provider[AsyncOpenAI]):
                 'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable'
             )

-        if not api_key and 'OPENAI_API_KEY' not in os.environ:  # pragma: no cover
+        if not api_key and 'AZURE_OPENAI_API_KEY' not in os.environ:  # pragma: no cover
             raise UserError(
-                'Must provide one of the `api_key` argument or the `OPENAI_API_KEY` environment variable'
+                'Must provide one of the `api_key` argument or the `AZURE_OPENAI_API_KEY` environment variable'
             )

         if not api_version and 'OPENAI_API_VERSION' not in os.environ:  # pragma: no cover
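
Taken together with the `infer_provider` addition above, the corrected variable names let Azure be configured entirely from the environment. A minimal sketch (endpoint, key, and API version values are placeholders):

    # Illustrative only: AzureProvider falls back to these when arguments are omitted.
    import os

    os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://my-resource.openai.azure.com'
    os.environ['AZURE_OPENAI_API_KEY'] = 'placeholder-key'
    os.environ['OPENAI_API_VERSION'] = '2024-06-01'  # placeholder version string

    from pydantic_ai.providers import infer_provider

    provider = infer_provider('azure')  # now resolves to AzureProvider()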