pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl


@@ -197,7 +197,7 @@ class GroqAgentModel(AgentModel):
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
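
The change above lets callers override the default through `model_settings`; the fallback is plain `dict.get` with a default. A minimal standalone sketch of that behaviour (ordinary Python, not the library's code; `NOT_GIVEN` is a stand-in sentinel here): an explicit setting always wins, otherwise the old default of `True` when tools are registered is kept.

from typing import Any

NOT_GIVEN = object()  # illustrative sentinel, not the real openai NOT_GIVEN

def resolve_parallel_tool_calls(model_settings: dict[str, Any], tools: list[Any]) -> Any:
    # an explicit 'parallel_tool_calls' entry wins; otherwise keep the old behaviour
    return model_settings.get('parallel_tool_calls', True if tools else NOT_GIVEN)

assert resolve_parallel_tool_calls({}, tools=['my_tool']) is True
assert resolve_parallel_tool_calls({'parallel_tool_calls': False}, tools=['my_tool']) is False
assert resolve_parallel_tool_calls({}, tools=[]) is NOT_GIVEN
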
@@ -207,8 +207,7 @@ class GroqAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -220,17 +219,20 @@ class GroqAgentModel(AgentModel):
                 items.append(
                     ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
                 )
-        return ModelResponse(items, timestamp=timestamp)
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
+        return GroqStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
@@ -36,6 +36,7 @@ from . import (
     Model,
     StreamedResponse,
     cached_async_http_client,
+    check_allow_model_requests,
 )

 try:
@@ -130,6 +131,7 @@ class MistralModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
+        check_allow_model_requests()
         return MistralAgentModel(
             self.client,
             self.model_name,
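
`check_allow_model_requests()` is the guard pydantic-ai uses to block accidental real API calls (e.g. in test suites); this release adds the call to the Mistral, Ollama and VertexAI models as well. A hedged sketch of how the guard is toggled — `ALLOW_MODEL_REQUESTS` is the module-level switch in `pydantic_ai.models`, and the exact exception type raised is an assumption here:

from pydantic_ai import models

models.ALLOW_MODEL_REQUESTS = False  # e.g. set in a test suite's conftest

try:
    models.check_allow_model_requests()
except RuntimeError as exc:  # assumed exception type; the point is that it raises
    print(f'blocked: {exc}')
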
@@ -147,7 +149,7 @@ class MistralAgentModel(AgentModel):
     """Implementation of `AgentModel` for Mistral models."""

     client: Mistral
-    model_name: str
+    model_name: MistralModelName
     allow_text_result: bool
     function_tools: list[ToolDefinition]
     result_tools: list[ToolDefinition]
@@ -265,8 +267,7 @@ class MistralAgentModel(AgentModel):
         ]
         return tools if tools else None

-    @staticmethod
-    def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
+    def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         assert response.choices, 'Unexpected empty response choice.'

@@ -288,10 +289,10 @@ class MistralAgentModel(AgentModel):
             tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
             parts.append(tool)

-        return ModelResponse(parts, timestamp=timestamp)
+        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
     async def _process_streamed_response(
+        self,
         result_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> StreamedResponse:
@@ -306,7 +307,12 @@ class MistralAgentModel(AgentModel):
         else:
             timestamp = datetime.now(tz=timezone.utc)

-        return MistralStreamedResponse(peekable_response, timestamp, {c.name: c for c in result_tools})
+        return MistralStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=timestamp,
+            _result_tools={c.name: c for c in result_tools},
+        )

     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -10,6 +10,7 @@ from . import (
     AgentModel,
     Model,
     cached_async_http_client,
+    check_allow_model_requests,
 )

 try:
@@ -25,6 +26,7 @@ from .openai import OpenAIModel

 CommonOllamaModelNames = Literal[
     'codellama',
+    'deepseek-r1',
     'gemma',
     'gemma2',
     'llama3',
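
With `'deepseek-r1'` added to `CommonOllamaModelNames`, the name now autocompletes, though `OllamaModel` still accepts arbitrary strings. A minimal usage sketch (the default local `base_url` is assumed, as in earlier releases):

from pydantic_ai.models.ollama import OllamaModel

# requires a local Ollama server with the model pulled: `ollama pull deepseek-r1`
model = OllamaModel(model_name='deepseek-r1')
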
@@ -110,6 +112,7 @@ class OllamaModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
+        check_allow_model_requests()
         return await self.openai_model.agent_model(
             function_tools=function_tools,
             allow_text_result=allow_text_result,
@@ -51,6 +51,8 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama)
 """

+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+

 @dataclass(init=False)
 class OpenAIModel(Model):
@@ -63,6 +65,7 @@ class OpenAIModel(Model):

     model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     def __init__(
         self,
@@ -72,6 +75,7 @@ class OpenAIModel(Model):
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.

@@ -87,6 +91,8 @@ class OpenAIModel(Model):
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
+                In the future, this may be inferred from the model name.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
@@ -98,6 +104,7 @@ class OpenAIModel(Model):
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+        self.system_prompt_role = system_prompt_role

     async def agent_model(
         self,
@@ -115,6 +122,7 @@ class OpenAIModel(Model):
             self.model_name,
             allow_text_result,
             tools,
+            self.system_prompt_role,
         )

     def name(self) -> str:
@@ -140,6 +148,7 @@ class OpenAIAgentModel(AgentModel):
     model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    system_prompt_role: OpenAISystemPromptRole | None

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -186,7 +195,7 @@ class OpenAIAgentModel(AgentModel):
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -197,8 +206,7 @@ class OpenAIAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -208,23 +216,25 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, timestamp=timestamp)
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return OpenAIStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
+        return OpenAIStreamedResponse(
+            _model_name=self.model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -246,11 +256,15 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)

-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+                if self.system_prompt_role == 'developer':
+                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
+                elif self.system_prompt_role == 'user':
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                else:
+                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
             elif isinstance(part, ToolReturnPart):
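
The branching above only affects how a `SystemPromptPart` is serialized; an equivalent standalone sketch of the role selection, using plain dicts instead of the typed `ChatCompletion*MessageParam` classes:

def map_system_prompt(content: str, system_prompt_role: str | None) -> dict[str, str]:
    # 'developer' and 'user' are the two overrides; anything else keeps the default 'system'
    role = system_prompt_role if system_prompt_role in ('developer', 'user') else 'system'
    return {'role': role, 'content': content}

assert map_system_prompt('Be terse.', None) == {'role': 'system', 'content': 'Be terse.'}
assert map_system_prompt('Be terse.', 'developer')['role'] == 'developer'
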
@@ -129,6 +129,7 @@ class TestAgentModel(AgentModel):
     result: _utils.Either[str | None, Any | None]
     result_tools: list[ToolDefinition]
     seed: int
+    model_name: str = 'test'

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -142,7 +143,7 @@ class TestAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(model_response, messages)
+        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -151,7 +152,8 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
+                model_name=self.model_name,
             )

         if messages:
@@ -177,7 +179,7 @@ class TestAgentModel(AgentModel):
                            if tool.name in new_retry_names
                        ]
                    )
-                return ModelResponse(parts=retry_parts)
+                return ModelResponse(parts=retry_parts, model_name=self.model_name)

         if response_text := self.result.left:
             if response_text.value is None:
@@ -189,20 +191,26 @@ class TestAgentModel(AgentModel):
                         if isinstance(part, ToolReturnPart):
                             output[part.tool_name] = part.content
                 if output:
-                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
+                    return ModelResponse(
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                    )
                 else:
-                    return ModelResponse.from_text('success (no tool calls)')
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
             else:
-                return ModelResponse.from_text(response_text.value)
+                return ModelResponse(parts=[TextPart(response_text.value)], model_name=self.model_name)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)], model_name=self.model_name
+                )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)], model_name=self.model_name
+                )


 @dataclass
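
`ModelResponse.from_text(...)` calls are replaced throughout with explicit construction so the new `model_name` field can be populated. The replacement pattern, taken directly from the hunk above:

from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(parts=[TextPart('success (no tool calls)')], model_name='test')
print(response.model_name, response.parts[0].content)
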
@@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient
 from .._utils import run_in_executor
 from ..exceptions import UserError
 from ..tools import ToolDefinition
-from . import Model, cached_async_http_client
+from . import Model, cached_async_http_client, check_allow_model_requests
 from .gemini import GeminiAgentModel, GeminiModelName

 try:
@@ -114,6 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
pydantic_ai/result.py CHANGED
@@ -11,35 +11,49 @@ import logfire_api
 from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
-from .tools import AgentDeps, RunContext
+from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits

-__all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'
+__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'


-ResultData = TypeVar('ResultData', default=str)
-"""Type variable for the result data of a run."""
+T = TypeVar('T')
+"""An invariant TypeVar."""
+ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
+"""
+An invariant type variable for the result data of a model.
+
+We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
+in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
+possessing ResultValidator's are covariant in the result data type, but in practice this is rarely an issue, and
+changing it would have negative consequences for the ergonomics of the library.
+
+At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
+resolve these potential variance issues.
+"""
+ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
+"""Covariant type variable for the result data type of a run."""

 ResultValidatorFunc = Union[
-    Callable[[RunContext[AgentDeps], ResultData], ResultData],
-    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
-    Callable[[ResultData], ResultData],
-    Callable[[ResultData], Awaitable[ResultData]],
+    Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
+    Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
+    Callable[[ResultDataT_inv], ResultDataT_inv],
+    Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
 ]
 """
-A function that always takes `ResultData` and returns `ResultData` and:
+A function that always takes and returns the same type of data (which is the result type of an agent run), and:

 * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
 * may or may not be async

-Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
+Usage `ResultValidatorFunc[AgentDeps, T]`.
 """

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
-class _BaseRunResult(ABC, Generic[ResultData]):
+class _BaseRunResult(ABC, Generic[ResultDataT]):
     """Base type for results.

     You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
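
The renamed type variables don't change how validators are written: a validator still takes and returns the run's result type (str by default). A hedged end-to-end sketch matching the `ResultValidatorFunc` shape, assuming the `result_validator` decorator, the `'test'` model shortcut, and `ModelRetry` from `pydantic_ai` as in other 0.0.x releases:

from pydantic_ai import Agent, ModelRetry

agent = Agent('test')  # TestModel: no network calls

@agent.result_validator
def ensure_non_empty(data: str) -> str:
    # takes and returns the same type; raise ModelRetry to ask the model to try again
    if not data.strip():
        raise ModelRetry('empty response, please try again')
    return data

result = agent.run_sync('hello')
print(result.data)
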
@@ -119,10 +133,10 @@ class _BaseRunResult(ABC, Generic[ResultData]):


 @dataclass
-class RunResult(_BaseRunResult[ResultData]):
+class RunResult(_BaseRunResult[ResultDataT]):
     """Result of a non-streamed run."""

-    data: ResultData
+    data: ResultDataT
     """Data from the final response in the run."""
     _result_tool_name: str | None
     _usage: Usage
@@ -165,14 +179,14 @@ class RunResult(_BaseRunResult[ResultData]):


 @dataclass
-class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
+class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]):
     """Result of a streamed run that returns structured data via a tool call."""

     _usage_limits: UsageLimits | None
     _stream_response: models.StreamedResponse
-    _result_schema: _result.ResultSchema[ResultData] | None
-    _run_ctx: RunContext[AgentDeps]
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
+    _result_schema: _result.ResultSchema[ResultDataT] | None
+    _run_ctx: RunContext[AgentDepsT]
+    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
     _result_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
     is_complete: bool = field(default=False, init=False)
@@ -185,7 +199,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
     """

-    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultData]:
+    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
         """Stream the response as an async iterable.

         The pydantic validator for structured data will be called in
@@ -269,7 +283,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                     yield combined_validated_text

                 lf_span.set_attribute('combined_text', combined_validated_text)
-                await self._marked_completed(_messages.ModelResponse.from_text(combined_validated_text))
+                await self._marked_completed(
+                    _messages.ModelResponse(
+                        parts=[_messages.TextPart(combined_validated_text)],
+                        model_name=self._stream_response.model_name(),
+                    )
+                )

     async def stream_structured(
         self, *, debounce_by: float | None = 0.1
@@ -306,7 +325,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
             lf_span.set_attribute('structured_response', msg)
             await self._marked_completed(msg)

-    async def get_data(self) -> ResultData:
+    async def get_data(self) -> ResultDataT:
         """Stream the whole response, validate and return it."""
         usage_checking_stream = _get_usage_checking_stream_response(
             self._stream_response, self._usage_limits, self.usage
@@ -332,7 +351,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):

     async def validate_structured_result(
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
-    ) -> ResultData:
+    ) -> ResultDataT:
         """Validate a structured result message."""
         if self._result_schema is not None and self._result_tool_name is not None:
             match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
@@ -351,17 +370,17 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
             for validator in self._result_validators:
                 text = await validator.validate(
-                    text,  # pyright: ignore[reportArgumentType]
+                    text,
                     None,
                     self._run_ctx,
                 )
-            # Since there is no result tool, we can assume that str is compatible with ResultData
-            return cast(ResultData, text)
+            # Since there is no result tool, we can assume that str is compatible with ResultDataT
+            return cast(ResultDataT, text)

     async def _validate_text_result(self, text: str) -> str:
         for validator in self._result_validators:
-            text = await validator.validate(  # pyright: ignore[reportAssignmentType]
-                text,  # pyright: ignore[reportArgumentType]
+            text = await validator.validate(
+                text,
                 None,
                 self._run_ctx,
             )
pydantic_ai/settings.py CHANGED
@@ -12,7 +12,8 @@ if TYPE_CHECKING:
 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.

-    Here we include only settings which apply to multiple models / model providers.
+    Here we include only settings which apply to multiple models / model providers,
+    though not all of these settings are supported by all models.
     """

     max_tokens: int
@@ -24,6 +25,8 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Cohere
+    * Mistral
     """

     temperature: float
@@ -40,6 +43,8 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Cohere
+    * Mistral
     """

     top_p: float
@@ -55,6 +60,8 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Cohere
+    * Mistral
     """

     timeout: float | Timeout
@@ -66,6 +73,16 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Mistral
+    """
+
+    parallel_tool_calls: bool
+    """Whether to allow parallel tool calls.
+
+    Supported by:
+    * OpenAI
+    * Groq
+    * Anthropic
     """

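
Since `ModelSettings` is a `TypedDict` with `total=False`, a plain dict literal type-checks, and per the updated docstring not every provider supports every key. A sketch of the new setting alongside existing ones; how it is passed into a run is hedged in the comment below:

from pydantic_ai.settings import ModelSettings

settings: ModelSettings = {
    'max_tokens': 512,
    'temperature': 0.2,
    'parallel_tool_calls': False,  # new in this release; honoured by OpenAI and Groq
}

# typically handed to an agent run, e.g. agent.run_sync(..., model_settings=settings)
# (the exact parameter name on Agent.run* is assumed here, not shown in this diff)
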