pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

pydantic_ai/messages.py CHANGED
@@ -6,7 +6,6 @@ from typing import Annotated, Any, Literal, Union, cast, overload
 
 import pydantic
 import pydantic_core
-from typing_extensions import Self, assert_never
 
 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -168,22 +167,6 @@ class TextPart:
         return bool(self.content)
 
 
-@dataclass
-class ArgsJson:
-    """Tool arguments as a JSON string."""
-
-    args_json: str
-    """A JSON string of arguments."""
-
-
-@dataclass
-class ArgsDict:
-    """Tool arguments as a Python dictionary."""
-
-    args_dict: dict[str, Any]
-    """A python dictionary of arguments."""
-
-
 @dataclass
 class ToolCallPart:
     """A tool call from a model."""
@@ -191,10 +174,10 @@ class ToolCallPart:
     tool_name: str
     """The name of the tool to call."""
 
-    args: ArgsJson | ArgsDict
+    args: str | dict[str, Any]
     """The arguments to pass to the tool.
 
-    Either as JSON or a Python dictionary depending on how data was returned.
+    This is stored either as a JSON string or a Python dictionary depending on how data was received.
     """
 
     tool_call_id: str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
     part_kind: Literal['tool-call'] = 'tool-call'
     """Part type identifier, this is available on all parts as a discriminator."""
 
-    @classmethod
-    def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
-        """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
-        if isinstance(args, str):
-            return cls(tool_name, ArgsJson(args), tool_call_id)
-        elif isinstance(args, dict):
-            return cls(tool_name, ArgsDict(args), tool_call_id)
-        else:
-            assert_never(args)
-
     def args_as_dict(self) -> dict[str, Any]:
         """Return the arguments as a Python dictionary.
 
         This is just for convenience with models that require dicts as input.
         """
-        if isinstance(self.args, ArgsDict):
-            return self.args.args_dict
-        args = pydantic_core.from_json(self.args.args_json)
+        if isinstance(self.args, dict):
+            return self.args
+        args = pydantic_core.from_json(self.args)
         assert isinstance(args, dict), 'args should be a dict'
         return cast(dict[str, Any], args)
@@ -229,16 +202,18 @@ class ToolCallPart:
 
         This is just for convenience with models that require JSON strings as input.
         """
-        if isinstance(self.args, ArgsJson):
-            return self.args.args_json
-        return pydantic_core.to_json(self.args.args_dict).decode()
+        if isinstance(self.args, str):
+            return self.args
+        return pydantic_core.to_json(self.args).decode()
 
     def has_content(self) -> bool:
         """Return `True` if the arguments contain any data."""
-        if isinstance(self.args, ArgsDict):
-            return any(self.args.args_dict.values())
+        if isinstance(self.args, dict):
+            # TODO: This should probably return True if you have the value False, or 0, etc.
+            #   It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
+            return any(self.args.values())
         else:
-            return bool(self.args.args_json)
+            return bool(self.args)
 
 
 ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
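
The net effect of these hunks is that call sites construct `ToolCallPart` directly from raw arguments. A minimal sketch of the new usage, assuming pydantic-ai-slim 0.0.22; the hypothetical values are for illustration, and only `args_as_dict` is visible by name in this diff:

from pydantic_ai.messages import ToolCallPart

# Raw args go straight into the constructor; the ArgsJson/ArgsDict
# wrappers and the from_raw_args classmethod no longer exist.
part_from_dict = ToolCallPart(tool_name='get_weather', args={'city': 'Paris'})
part_from_json = ToolCallPart(tool_name='get_weather', args='{"city": "Paris"}')

# Both accessors normalize to a dict regardless of the stored form.
assert part_from_dict.args_as_dict() == {'city': 'Paris'}
assert part_from_json.args_as_dict() == {'city': 'Paris'}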
@@ -331,7 +306,7 @@ class ToolCallPartDelta:
         if self.tool_name_delta is None or self.args_delta is None:
             return None
 
-        return ToolCallPart.from_raw_args(
+        return ToolCallPart(
            self.tool_name_delta,
            self.args_delta,
            self.tool_call_id,
@@ -396,7 +371,7 @@ class ToolCallPartDelta:
 
         # If we now have enough data to create a full ToolCallPart, do so
         if delta.tool_name_delta is not None and delta.args_delta is not None:
-            return ToolCallPart.from_raw_args(
+            return ToolCallPart(
                 delta.tool_name_delta,
                 delta.args_delta,
                 delta.tool_call_id,
@@ -412,15 +387,15 @@ class ToolCallPartDelta:
             part = replace(part, tool_name=tool_name)
 
         if isinstance(self.args_delta, str):
-            if not isinstance(part.args, ArgsJson):
+            if not isinstance(part.args, str):
                 raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
-            updated_json = part.args.args_json + self.args_delta
-            part = replace(part, args=ArgsJson(updated_json))
+            updated_json = part.args + self.args_delta
+            part = replace(part, args=updated_json)
         elif isinstance(self.args_delta, dict):
-            if not isinstance(part.args, ArgsDict):
+            if not isinstance(part.args, dict):
                 raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
-            updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
-            part = replace(part, args=ArgsDict(updated_dict))
+            updated_dict = {**(part.args or {}), **self.args_delta}
+            part = replace(part, args=updated_dict)
 
         if self.tool_call_id:
             # Replace the tool_call_id entirely if given
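
A sketch of what the streaming-delta logic above now amounts to: JSON deltas are plain string concatenation onto `args`, with no wrapper objects to unwrap (the values below are hypothetical):

from dataclasses import replace

from pydantic_ai.messages import ToolCallPart

# Accumulate a partial JSON string exactly as the delta code does.
part = ToolCallPart(tool_name='get_weather', args='{"city": "Par')
part = replace(part, args=part.args + 'is"}')
assert part.args_as_dict() == {'city': 'Paris'}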
pydantic_ai/models/__init__.py CHANGED
@@ -12,9 +12,10 @@ from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING
 
 import httpx
+from typing_extensions import Literal
 
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
@@ -27,58 +28,6 @@ if TYPE_CHECKING:
 
 
 KnownModelName = Literal[
-    'openai:gpt-4o',
-    'openai:gpt-4o-mini',
-    'openai:gpt-4-turbo',
-    'openai:gpt-4',
-    'openai:o1-preview',
-    'openai:o1-mini',
-    'openai:o1',
-    'openai:gpt-3.5-turbo',
-    'groq:llama-3.3-70b-versatile',
-    'groq:llama-3.1-70b-versatile',
-    'groq:llama3-groq-70b-8192-tool-use-preview',
-    'groq:llama3-groq-8b-8192-tool-use-preview',
-    'groq:llama-3.1-70b-specdec',
-    'groq:llama-3.1-8b-instant',
-    'groq:llama-3.2-1b-preview',
-    'groq:llama-3.2-3b-preview',
-    'groq:llama-3.2-11b-vision-preview',
-    'groq:llama-3.2-90b-vision-preview',
-    'groq:llama3-70b-8192',
-    'groq:llama3-8b-8192',
-    'groq:mixtral-8x7b-32768',
-    'groq:gemma2-9b-it',
-    'groq:gemma-7b-it',
-    'google-gla:gemini-1.5-flash',
-    'google-gla:gemini-1.5-pro',
-    'google-gla:gemini-2.0-flash-exp',
-    'google-vertex:gemini-1.5-flash',
-    'google-vertex:gemini-1.5-pro',
-    'google-vertex:gemini-2.0-flash-exp',
-    'mistral:mistral-small-latest',
-    'mistral:mistral-large-latest',
-    'mistral:codestral-latest',
-    'mistral:mistral-moderation-latest',
-    'ollama:codellama',
-    'ollama:deepseek-r1',
-    'ollama:gemma',
-    'ollama:gemma2',
-    'ollama:llama3',
-    'ollama:llama3.1',
-    'ollama:llama3.2',
-    'ollama:llama3.2-vision',
-    'ollama:llama3.3',
-    'ollama:mistral',
-    'ollama:mistral-nemo',
-    'ollama:mixtral',
-    'ollama:phi3',
-    'ollama:phi4',
-    'ollama:qwq',
-    'ollama:qwen',
-    'ollama:qwen2',
-    'ollama:qwen2.5',
-    'ollama:starcoder2',
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
@@ -98,6 +47,108 @@ KnownModelName = Literal[
     'cohere:command-r-plus-04-2024',
     'cohere:command-r-plus-08-2024',
     'cohere:command-r7b-12-2024',
+    'google-gla:gemini-1.0-pro',
+    'google-gla:gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash-8b',
+    'google-gla:gemini-1.5-pro',
+    'google-gla:gemini-2.0-flash-exp',
+    'google-gla:gemini-2.0-flash-thinking-exp-01-21',
+    'google-gla:gemini-exp-1206',
+    'google-vertex:gemini-1.0-pro',
+    'google-vertex:gemini-1.5-flash',
+    'google-vertex:gemini-1.5-flash-8b',
+    'google-vertex:gemini-1.5-pro',
+    'google-vertex:gemini-2.0-flash-exp',
+    'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
+    'google-vertex:gemini-exp-1206',
+    'gpt-3.5-turbo',
+    'gpt-3.5-turbo-0125',
+    'gpt-3.5-turbo-0301',
+    'gpt-3.5-turbo-0613',
+    'gpt-3.5-turbo-1106',
+    'gpt-3.5-turbo-16k',
+    'gpt-3.5-turbo-16k-0613',
+    'gpt-4',
+    'gpt-4-0125-preview',
+    'gpt-4-0314',
+    'gpt-4-0613',
+    'gpt-4-1106-preview',
+    'gpt-4-32k',
+    'gpt-4-32k-0314',
+    'gpt-4-32k-0613',
+    'gpt-4-turbo',
+    'gpt-4-turbo-2024-04-09',
+    'gpt-4-turbo-preview',
+    'gpt-4-vision-preview',
+    'gpt-4o',
+    'gpt-4o-2024-05-13',
+    'gpt-4o-2024-08-06',
+    'gpt-4o-2024-11-20',
+    'gpt-4o-audio-preview',
+    'gpt-4o-audio-preview-2024-10-01',
+    'gpt-4o-audio-preview-2024-12-17',
+    'gpt-4o-mini',
+    'gpt-4o-mini-2024-07-18',
+    'gpt-4o-mini-audio-preview',
+    'gpt-4o-mini-audio-preview-2024-12-17',
+    'groq:gemma2-9b-it',
+    'groq:llama-3.1-8b-instant',
+    'groq:llama-3.2-11b-vision-preview',
+    'groq:llama-3.2-1b-preview',
+    'groq:llama-3.2-3b-preview',
+    'groq:llama-3.2-90b-vision-preview',
+    'groq:llama-3.3-70b-specdec',
+    'groq:llama-3.3-70b-versatile',
+    'groq:llama3-70b-8192',
+    'groq:llama3-8b-8192',
+    'groq:mixtral-8x7b-32768',
+    'mistral:codestral-latest',
+    'mistral:mistral-large-latest',
+    'mistral:mistral-moderation-latest',
+    'mistral:mistral-small-latest',
+    'o1',
+    'o1-2024-12-17',
+    'o1-mini',
+    'o1-mini-2024-09-12',
+    'o1-preview',
+    'o1-preview-2024-09-12',
+    'openai:chatgpt-4o-latest',
+    'openai:gpt-3.5-turbo',
+    'openai:gpt-3.5-turbo-0125',
+    'openai:gpt-3.5-turbo-0301',
+    'openai:gpt-3.5-turbo-0613',
+    'openai:gpt-3.5-turbo-1106',
+    'openai:gpt-3.5-turbo-16k',
+    'openai:gpt-3.5-turbo-16k-0613',
+    'openai:gpt-4',
+    'openai:gpt-4-0125-preview',
+    'openai:gpt-4-0314',
+    'openai:gpt-4-0613',
+    'openai:gpt-4-1106-preview',
+    'openai:gpt-4-32k',
+    'openai:gpt-4-32k-0314',
+    'openai:gpt-4-32k-0613',
+    'openai:gpt-4-turbo',
+    'openai:gpt-4-turbo-2024-04-09',
+    'openai:gpt-4-turbo-preview',
+    'openai:gpt-4-vision-preview',
+    'openai:gpt-4o',
+    'openai:gpt-4o-2024-05-13',
+    'openai:gpt-4o-2024-08-06',
+    'openai:gpt-4o-2024-11-20',
+    'openai:gpt-4o-audio-preview',
+    'openai:gpt-4o-audio-preview-2024-10-01',
+    'openai:gpt-4o-audio-preview-2024-12-17',
+    'openai:gpt-4o-mini',
+    'openai:gpt-4o-mini-2024-07-18',
+    'openai:gpt-4o-mini-audio-preview',
+    'openai:gpt-4o-mini-audio-preview-2024-12-17',
+    'openai:o1',
+    'openai:o1-2024-12-17',
+    'openai:o1-mini',
+    'openai:o1-mini-2024-09-12',
+    'openai:o1-preview',
+    'openai:o1-preview-2024-09-12',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -291,10 +342,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .mistral import MistralModel
 
         return MistralModel(model[8:])
-    elif model.startswith('ollama:'):
-        from .ollama import OllamaModel
-
-        return OllamaModel(model[7:])
    elif model.startswith('anthropic'):
        from .anthropic import AnthropicModel
 
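
Known model names remain plain strings, so existing agents keep working unchanged; only the `ollama:` prefix loses its dispatch branch in `infer_model`. A minimal sketch, assuming the `openai` optional group is installed and `OPENAI_API_KEY` is set:

from pydantic_ai import Agent

# Prefixed names like this still resolve to provider-specific Model
# classes; 'ollama:*' names are no longer resolved in this release.
agent = Agent('openai:gpt-4o-mini')
result = agent.run_sync('Say hello.')
print(result.data)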
pydantic_ai/models/anthropic.py CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -41,6 +40,7 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        MetadataParam,
         RawContentBlockDeltaEvent,
         RawContentBlockStartEvent,
         RawContentBlockStopEvent,
@@ -79,6 +79,15 @@ Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/model
 """
 
 
+class AnthropicModelSettings(ModelSettings):
+    """Settings used for an Anthropic model request."""
+
+    anthropic_metadata: MetadataParam
+    """An object describing metadata about the request.
+
+    Contains `user_id`, an external identifier for the user who is associated with the request."""
+
+
 @dataclass(init=False)
 class AnthropicModel(Model):
     """A model that uses the Anthropic API.
@@ -167,35 +176,33 @@ class AnthropicAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, model_settings)
+        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings)
+        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
     ) -> AnthropicMessage:
         pass
 
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
-        model_settings = model_settings or {}
-
         tool_choice: ToolChoiceParam | None
 
         if not self.tools:
@@ -222,6 +229,7 @@ class AnthropicAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
         )
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
  def _process_response(self, response: AnthropicMessage) -> ModelResponse:
@@ -233,7 +241,7 @@ class AnthropicAgentModel(AgentModel):
233
241
  else:
234
242
  assert isinstance(item, ToolUseBlock), 'unexpected item type'
235
243
  items.append(
236
- ToolCallPart.from_raw_args(
244
+ ToolCallPart(
237
245
  tool_name=item.name,
238
246
  args=cast(dict[str, Any], item.input),
239
247
  tool_call_id=item.id,
@@ -310,7 +318,6 @@ class AnthropicAgentModel(AgentModel):
 
 
 def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
     return ToolUseBlockParam(
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
pydantic_ai/models/cohere.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal, TypeAlias, Union
+from typing import Literal, Union, cast
 
 from cohere import TextAssistantMessageContentItem
+from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import result
@@ -51,24 +52,30 @@ except ImportError as _import_error:
     "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
 ) from _import_error
 
-CohereModelName: TypeAlias = Union[
-    str,
-    Literal[
-        'c4ai-aya-expanse-32b',
-        'c4ai-aya-expanse-8b',
-        'command',
-        'command-light',
-        'command-light-nightly',
-        'command-nightly',
-        'command-r',
-        'command-r-03-2024',
-        'command-r-08-2024',
-        'command-r-plus',
-        'command-r-plus-04-2024',
-        'command-r-plus-08-2024',
-        'command-r7b-12-2024',
-    ],
+NamedCohereModels = Literal[
+    'c4ai-aya-expanse-32b',
+    'c4ai-aya-expanse-8b',
+    'command',
+    'command-light',
+    'command-light-nightly',
+    'command-nightly',
+    'command-r',
+    'command-r-03-2024',
+    'command-r-08-2024',
+    'command-r-plus',
+    'command-r-plus-04-2024',
+    'command-r-plus-08-2024',
+    'command-r7b-12-2024',
 ]
+"""Latest / most popular named Cohere models."""
+
+CohereModelName = Union[NamedCohereModels, str]
+
+
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings
 
 
 @dataclass(init=False)
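
Because `CohereModelName` is now `Union[NamedCohereModels, str]`, arbitrary model ids type-check alongside the named literals. A small sketch (the custom id below is hypothetical):

from pydantic_ai.models.cohere import CohereModelName

named: CohereModelName = 'command-r-plus'       # covered by NamedCohereModels
custom: CohereModelName = 'my-finetuned-model'  # any str is also accepted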
@@ -90,6 +97,7 @@ class CohereModel(Model):
         *,
         api_key: str | None = None,
         cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
     ):
         """Initialize an Cohere model.
 
@@ -97,16 +105,18 @@ class CohereModel(Model):
             model_name: The name of the Cohere model to use. List of model names
                 available [here](https://docs.cohere.com/docs/models#command).
             api_key: The API key to use for authentication, if not provided, the
-                `COHERE_API_KEY` environment variable will be used if available.
+                `CO_API_KEY` environment variable will be used if available.
             cohere_client: An existing Cohere async client to use. If provided,
-                `api_key` must be `None`.
+                `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: CohereModelName = model_name
         if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
 
     async def agent_model(
         self,
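
A sketch of the new `http_client` parameter, assuming the `cohere` optional group is installed and `CO_API_KEY` is set:

import httpx

from pydantic_ai.models.cohere import CohereModel

# Reuse a single httpx.AsyncClient (with its own timeout/proxy config)
# for all requests the Cohere SDK makes.
http_client = httpx.AsyncClient(timeout=30)
model = CohereModel('command-r-plus', http_client=http_client)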
@@ -153,16 +163,15 @@ class CohereAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, model_settings)
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     async def _chat(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: CohereModelSettings,
     ) -> ChatResponse:
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
         return await self.client.chat(
             model=self.model_name,
             messages=cohere_messages,
@@ -170,6 +179,9 @@ class CohereAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
         )
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
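
These keys are read from the settings dict and simply forwarded now. A sketch of passing them per run, assuming the `cohere` optional group is installed and that `seed`, `presence_penalty`, and `frequency_penalty` exist on the base `ModelSettings` in this release:

from pydantic_ai import Agent

agent = Agent('cohere:command-r-plus')
result = agent.run_sync(
    'Write a haiku about diffs.',
    # These settings now reach Cohere's chat endpoint per the hunk above.
    model_settings={'seed': 42, 'presence_penalty': 0.5, 'frequency_penalty': 0.2},
)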
@@ -183,7 +195,7 @@ class CohereAgentModel(AgentModel):
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:
                 parts.append(
-                    ToolCallPart.from_raw_args(
+                    ToolCallPart(
                         tool_name=c.function.name,
                         args=c.function.arguments,
                         tool_call_id=c.id,
pydantic_ai/models/gemini.py CHANGED
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
 
 import pydantic
@@ -40,7 +40,13 @@ from . import (
 )
 
 GeminiModelName = Literal[
-    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
+    'gemini-1.5-flash',
+    'gemini-1.5-flash-8b',
+    'gemini-1.5-pro',
+    'gemini-1.0-pro',
+    'gemini-2.0-flash-exp',
+    'gemini-2.0-flash-thinking-exp-01-21',
+    'gemini-exp-1206',
 ]
 """Named Gemini models.
 
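
The two newly named models type-check as `GeminiModelName` immediately. A minimal sketch, assuming `GEMINI_API_KEY` is set in the environment:

from pydantic_ai.models.gemini import GeminiModel

# Constructing the model directly with one of the names added above.
model = GeminiModel('gemini-2.0-flash-thinking-exp-01-21')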
@@ -48,6 +54,12 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 """
 
 
+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.
@@ -171,7 +183,9 @@ class GeminiAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(messages, False, model_settings) as http_response:
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
 
@@ -179,12 +193,12 @@ class GeminiAgentModel(AgentModel):
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
@@ -204,6 +218,10 @@ class GeminiAgentModel(AgentModel):
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -222,7 +240,7 @@ class GeminiAgentModel(AgentModel):
             url,
             content=request_json,
             headers=headers,
-            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()
@@ -398,6 +416,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float
 
 
 class _GeminiContent(TypedDict):
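
With the TypedDict fields above in place, the penalty settings flow through to `generation_config`. A sketch of setting them per run, assuming these keys exist on the base `ModelSettings` in this release:

from pydantic_ai import Agent

agent = Agent('google-gla:gemini-1.5-flash')
result = agent.run_sync(
    'Summarise this diff in one sentence.',
    # Forwarded into generation_config by the hunks above.
    model_settings={'presence_penalty': 0.3, 'frequency_penalty': 0.3},
)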
@@ -439,7 +459,7 @@ def _process_response_from_parts(
         items.append(TextPart(content=part['text']))
     elif 'function_call' in part:
         items.append(
-            ToolCallPart.from_raw_args(
+            ToolCallPart(
                 tool_name=part['function_call']['name'],
                 args=part['function_call']['args'],
             )