pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -41,6 +40,7 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        MetadataParam,
         RawContentBlockDeltaEvent,
         RawContentBlockStartEvent,
         RawContentBlockStopEvent,
@@ -79,6 +79,15 @@ Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/model
 """
 
 
+class AnthropicModelSettings(ModelSettings):
+    """Settings used for an Anthropic model request."""
+
+    anthropic_metadata: MetadataParam
+    """An object describing metadata about the request.
+
+    Contains `user_id`, an external identifier for the user who is associated with the request."""
+
+
 @dataclass(init=False)
 class AnthropicModel(Model):
     """A model that uses the Anthropic API.
@@ -167,35 +176,33 @@ class AnthropicAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, model_settings)
+        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings)
+        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
     ) -> AnthropicMessage:
         pass
 
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
-        model_settings = model_settings or {}
-
         tool_choice: ToolChoiceParam | None
 
         if not self.tools:
@@ -222,6 +229,7 @@ class AnthropicAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
         )
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
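The hunks above introduce `AnthropicModelSettings` and forward its `anthropic_metadata` field to the Anthropic Messages API `metadata` parameter. A minimal, hedged sketch of how the new setting might be passed, assuming the `Agent`/`AnthropicModel` interfaces of this release; the model name, prompt and user id are illustrative only:

# Sketch only: anthropic_metadata flowing through the new settings TypedDict.
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

agent = Agent(AnthropicModel('claude-3-5-sonnet-latest'))
result = agent.run_sync(
    'Summarise the release notes',
    model_settings=AnthropicModelSettings(
        # `user_id` is the external identifier Anthropic associates with the request.
        anthropic_metadata={'user_id': 'example-user-123'},
    ),
)
print(result.data)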
@@ -233,7 +241,7 @@ class AnthropicAgentModel(AgentModel):
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
-                    ToolCallPart.from_raw_args(
+                    ToolCallPart(
                         tool_name=item.name,
                         args=cast(dict[str, Any], item.input),
                         tool_call_id=item.id,
@@ -310,7 +318,6 @@ class AnthropicAgentModel(AgentModel):
 
 
 def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
    return ToolUseBlockParam(
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
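The two hunks above, like the matching hunks in the other model modules below, replace `ToolCallPart.from_raw_args(...)` with a plain `ToolCallPart(...)` call and drop the `ArgsDict`/`ArgsJson` wrappers, so `args` now travels as an ordinary dict or JSON string. A rough sketch of the new construction, inferred only from the call sites in this diff:

# Sketch: ToolCallPart now takes args directly (mapping or JSON string).
from pydantic_ai.messages import ToolCallPart

part_from_dict = ToolCallPart(tool_name='get_weather', args={'city': 'London'}, tool_call_id='call_1')
part_from_json = ToolCallPart(tool_name='get_weather', args='{"city": "London"}', tool_call_id='call_2')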
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal, TypeAlias, Union
+from typing import Literal, Union, cast
 
 from cohere import TextAssistantMessageContentItem
+from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import result
@@ -51,24 +52,30 @@ except ImportError as _import_error:
         "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
     ) from _import_error
 
-CohereModelName: TypeAlias = Union[
-    str,
-    Literal[
-        'c4ai-aya-expanse-32b',
-        'c4ai-aya-expanse-8b',
-        'command',
-        'command-light',
-        'command-light-nightly',
-        'command-nightly',
-        'command-r',
-        'command-r-03-2024',
-        'command-r-08-2024',
-        'command-r-plus',
-        'command-r-plus-04-2024',
-        'command-r-plus-08-2024',
-        'command-r7b-12-2024',
-    ],
+NamedCohereModels = Literal[
+    'c4ai-aya-expanse-32b',
+    'c4ai-aya-expanse-8b',
+    'command',
+    'command-light',
+    'command-light-nightly',
+    'command-nightly',
+    'command-r',
+    'command-r-03-2024',
+    'command-r-08-2024',
+    'command-r-plus',
+    'command-r-plus-04-2024',
+    'command-r-plus-08-2024',
+    'command-r7b-12-2024',
 ]
+"""Latest / most popular named Cohere models."""
+
+CohereModelName = Union[NamedCohereModels, str]
+
+
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings
 
 
 @dataclass(init=False)
@@ -90,6 +97,7 @@ class CohereModel(Model):
         *,
         api_key: str | None = None,
         cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
     ):
         """Initialize an Cohere model.
 
@@ -97,16 +105,18 @@ class CohereModel(Model):
             model_name: The name of the Cohere model to use. List of model names
                 available [here](https://docs.cohere.com/docs/models#command).
             api_key: The API key to use for authentication, if not provided, the
-                `COHERE_API_KEY` environment variable will be used if available.
+                `CO_API_KEY` environment variable will be used if available.
             cohere_client: An existing Cohere async client to use. If provided,
-                `api_key` must be `None`.
+                `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: CohereModelName = model_name
         if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
 
     async def agent_model(
         self,
@@ -153,16 +163,15 @@ class CohereAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, model_settings)
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     async def _chat(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: CohereModelSettings,
     ) -> ChatResponse:
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
         return await self.client.chat(
             model=self.model_name,
             messages=cohere_messages,
@@ -170,6 +179,9 @@ class CohereAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
         )
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
@@ -183,7 +195,7 @@ class CohereAgentModel(AgentModel):
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:
                 parts.append(
-                    ToolCallPart.from_raw_args(
+                    ToolCallPart(
                         tool_name=c.function.name,
                         args=c.function.arguments,
                         tool_call_id=c.id,
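The Cohere hunks above add an `http_client` parameter that is handed to `AsyncClientV2` as `httpx_client`, correct the documented environment variable to `CO_API_KEY`, and forward `seed` and the penalty settings. A minimal sketch of the new constructor, assuming the `pydantic_ai.models.cohere` module path and a hypothetical custom timeout:

# Sketch: supplying a custom httpx.AsyncClient to CohereModel (new in this release).
import httpx
from pydantic_ai.models.cohere import CohereModel

http_client = httpx.AsyncClient(timeout=30)  # hypothetical transport configuration
model = CohereModel('command-r-plus', http_client=http_client)
# With no api_key argument, the CO_API_KEY environment variable is used (per the updated docstring).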
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
 
 import pydantic
@@ -48,6 +48,12 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 """
 
 
+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.
@@ -171,7 +177,9 @@ class GeminiAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(messages, False, model_settings) as http_response:
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
 
@@ -179,12 +187,12 @@ class GeminiAgentModel(AgentModel):
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
@@ -204,6 +212,10 @@ class GeminiAgentModel(AgentModel):
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -222,7 +234,7 @@ class GeminiAgentModel(AgentModel):
             url,
             content=request_json,
             headers=headers,
-            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()
@@ -398,6 +410,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float
 
 
 class _GeminiContent(TypedDict):
@@ -439,7 +453,7 @@ def _process_response_from_parts(
             items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
             items.append(
-                ToolCallPart.from_raw_args(
+                ToolCallPart(
                     tool_name=part['function_call']['name'],
                     args=part['function_call']['args'],
                 )
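The Gemini hunks above thread `presence_penalty` and `frequency_penalty` from the settings dict into the request's `generation_config` and add the matching fields to `_GeminiGenerationConfig`. A minimal sketch, assuming these keys are now part of the shared `ModelSettings` TypedDict (as the `.get()` calls imply), so a plain dict works as the settings value:

# Sketch: penalty settings flowing into Gemini's generation_config.
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel

agent = Agent(GeminiModel('gemini-1.5-flash'))
result = agent.run_sync(
    'Write a limerick about type checking',
    model_settings={'presence_penalty': 0.5, 'frequency_penalty': 0.3},
)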
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, overload
+from typing import Literal, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -47,10 +47,7 @@
 
 GroqModelName = Literal[
     'llama-3.3-70b-versatile',
-    'llama-3.1-70b-versatile',
-    'llama3-groq-70b-8192-tool-use-preview',
-    'llama3-groq-8b-8192-tool-use-preview',
-    'llama-3.1-70b-specdec',
+    'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
     'llama-3.2-1b-preview',
     'llama-3.2-3b-preview',
@@ -60,7 +57,6 @@ GroqModelName = Literal[
     'llama3-8b-8192',
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
-    'gemma-7b-it',
 ]
 """Named Groq models.
 
@@ -68,6 +64,12 @@ See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """
 
 
+class GroqModelSettings(ModelSettings):
+    """Settings used for a Groq model request."""
+
+    # This class is a placeholder for any future groq-specific settings
+
+
 @dataclass(init=False)
 class GroqModel(Model):
     """A model that uses the Groq API.
@@ -155,31 +157,31 @@ class GroqAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -191,13 +193,11 @@ class GroqAgentModel(AgentModel):
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -205,6 +205,10 @@ class GroqAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -216,9 +220,7 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
-                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
-                )
+                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
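The Groq hunks above trim the named-model literal to currently served models, stop defaulting `parallel_tool_calls` to `True`, and forward `seed`, `presence_penalty`, `frequency_penalty`, and `logit_bias` when present. A minimal sketch, again assuming these keys on the shared `ModelSettings` TypedDict; the values (including the token id in `logit_bias`) are illustrative only:

# Sketch: the newly forwarded sampling settings on a Groq request.
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel

agent = Agent(GroqModel('llama-3.3-70b-versatile'))
result = agent.run_sync(
    'Give three facts about rust, the metal',
    model_settings={'seed': 42, 'frequency_penalty': 0.2, 'logit_bias': {'1234': -100}},
)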
@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Callable, Literal, Union
+from typing import Any, Callable, Literal, Union, cast
 
 import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
@@ -15,7 +15,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -85,6 +84,12 @@ Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_o
 """
 
 
+class MistralModelSettings(ModelSettings):
+    """Settings used for a Mistral model request."""
+
+    # This class is a placeholder for any future mistral-specific settings
+
+
 @dataclass(init=False)
 class MistralModel(Model):
     """A model that uses Mistral.
@@ -159,7 +164,7 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, model_settings)
+        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
@@ -167,15 +172,14 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, model_settings)
+        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(self.result_tools, response)
 
     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], model_settings: MistralModelSettings
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        model_settings = model_settings or {}
         response = await self.client.chat.complete_async(
             model=str(self.model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
@@ -187,6 +191,7 @@ class MistralAgentModel(AgentModel):
             temperature=model_settings.get('temperature', UNSET),
             top_p=model_settings.get('top_p', 1),
             timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+            random_seed=model_settings.get('seed', UNSET),
         )
         assert response, 'A unexpected empty response from Mistral.'
         return response
@@ -194,12 +199,11 @@ class MistralAgentModel(AgentModel):
     async def _stream_completions_create(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: MistralModelSettings,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
 
         if self.result_tools and self.function_tools or self.function_tools:
             # Function Calling
@@ -213,6 +217,8 @@ class MistralAgentModel(AgentModel):
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                presence_penalty=model_settings.get('presence_penalty'),
+                frequency_penalty=model_settings.get('frequency_penalty'),
             )
 
         elif self.result_tools:
@@ -317,18 +323,11 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
-        if isinstance(t.args, ArgsJson):
-            return MistralToolCall(
-                id=t.tool_call_id,
-                type='function',
-                function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_json),
-            )
-        else:
-            return MistralToolCall(
-                id=t.tool_call_id,
-                type='function',
-                function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
-            )
+        return MistralToolCall(
+            id=t.tool_call_id,
+            type='function',
+            function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
+        )
 
     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""
@@ -511,7 +510,7 @@ class MistralStreamedResponse(StreamedResponse):
                     continue
 
                 # The following part_id will be thrown away
-                return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
+                return ToolCallPart(tool_name=result_tool.name, args=output_json)
 
     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
@@ -569,7 +568,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa
     tool_call_id = tool_call.id or None
     func_call = tool_call.function
 
-    return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)
+    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
 
 
 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
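The Mistral hunks above map the generic `seed` setting onto Mistral's `random_seed` parameter, forward the penalty settings on streaming requests, and collapse the old `ArgsJson`/`ArgsDict` branching now that `ToolCallPart.args` is passed through unchanged. A minimal sketch of a seeded request, with an illustrative model name:

# Sketch: a reproducible Mistral request via the generic `seed` setting,
# which this release forwards as `random_seed`.
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

agent = Agent(MistralModel('mistral-large-latest'))
result = agent.run_sync('Pick a random fruit', model_settings={'seed': 7})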
@@ -600,7 +599,7 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
     elif isinstance(content, str):
         result = content
 
-    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and reponses`)
+    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
     if result and len(result) == 0:
         result = None
 
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, Union, overload
+from typing import Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -48,12 +48,18 @@ except ImportError as _import_error:
 OpenAIModelName = Union[ChatModel, str]
 """
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
 """
 
 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 
 
+class OpenAIModelSettings(ModelSettings):
+    """Settings used for an OpenAI model request."""
+
+    # This class is a placeholder for any future openai-specific settings
+
+
 @dataclass(init=False)
 class OpenAIModel(Model):
     """A model that uses the OpenAI API.
@@ -153,31 +159,31 @@ class OpenAIAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -189,13 +195,11 @@ class OpenAIAgentModel(AgentModel):
 
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -204,6 +208,10 @@ class OpenAIAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -215,7 +223,7 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
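As with Groq, the OpenAI hunks drop the `True if self.tools else NOT_GIVEN` default for `parallel_tool_calls` and pass `seed`, `presence_penalty`, `frequency_penalty`, and `logit_bias` through to `chat.completions.create`. A minimal sketch under the same assumption about the shared `ModelSettings` keys; model name and values are illustrative:

# Sketch: the newly forwarded OpenAI chat-completion settings.
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

agent = Agent(OpenAIModel('gpt-4o-mini'))
result = agent.run_sync(
    'Name a prime number between 100 and 200',
    model_settings={'seed': 123, 'presence_penalty': 0.1, 'frequency_penalty': 0.1},
)
print(result.data)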