pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.

@@ -52,11 +52,15 @@ KnownModelName = Literal[
52
52
  'google-gla:gemini-1.5-flash-8b',
53
53
  'google-gla:gemini-1.5-pro',
54
54
  'google-gla:gemini-2.0-flash-exp',
55
+ 'google-gla:gemini-2.0-flash-thinking-exp-01-21',
56
+ 'google-gla:gemini-exp-1206',
55
57
  'google-vertex:gemini-1.0-pro',
56
58
  'google-vertex:gemini-1.5-flash',
57
59
  'google-vertex:gemini-1.5-flash-8b',
58
60
  'google-vertex:gemini-1.5-pro',
59
61
  'google-vertex:gemini-2.0-flash-exp',
62
+ 'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
63
+ 'google-vertex:gemini-exp-1206',
60
64
  'gpt-3.5-turbo',
61
65
  'gpt-3.5-turbo-0125',
62
66
  'gpt-3.5-turbo-0301',
@@ -108,6 +112,8 @@ KnownModelName = Literal[
108
112
  'o1-mini-2024-09-12',
109
113
  'o1-preview',
110
114
  'o1-preview-2024-09-12',
115
+ 'o3-mini',
116
+ 'o3-mini-2025-01-31',
111
117
  'openai:chatgpt-4o-latest',
112
118
  'openai:gpt-3.5-turbo',
113
119
  'openai:gpt-3.5-turbo-0125',
@@ -145,6 +151,8 @@ KnownModelName = Literal[
145
151
  'openai:o1-mini-2024-09-12',
146
152
  'openai:o1-preview',
147
153
  'openai:o1-preview-2024-09-12',
154
+ 'openai:o3-mini',
155
+ 'openai:o3-mini-2025-01-31',
148
156
  'test',
149
157
  ]
150
158
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -153,49 +161,37 @@ KnownModelName = Literal[
153
161
  """
154
162
 
155
163
 
156
- class Model(ABC):
157
- """Abstract class for a model."""
164
+ @dataclass
165
+ class ModelRequestParameters:
166
+ """Configuration for an agent's request to a model, specifically related to tools and result handling."""
158
167
 
159
- @abstractmethod
160
- async def agent_model(
161
- self,
162
- *,
163
- function_tools: list[ToolDefinition],
164
- allow_text_result: bool,
165
- result_tools: list[ToolDefinition],
166
- ) -> AgentModel:
167
- """Create an agent model, this is called for each step of an agent run.
168
-
169
- This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
170
-
171
- Args:
172
- function_tools: The tools available to the agent.
173
- allow_text_result: Whether a plain text final response/result is permitted.
174
- result_tools: Tool definitions for the final result tool(s), if any.
175
-
176
- Returns:
177
- An agent model.
178
- """
179
- raise NotImplementedError()
168
+ function_tools: list[ToolDefinition]
169
+ allow_text_result: bool
170
+ result_tools: list[ToolDefinition]
180
171
 
181
- @abstractmethod
182
- def name(self) -> str:
183
- raise NotImplementedError()
184
172
 
173
+ class Model(ABC):
174
+ """Abstract class for a model."""
185
175
 
186
- class AgentModel(ABC):
187
- """Model configured for each step of an Agent run."""
176
+ _model_name: str
177
+ _system: str | None
188
178
 
189
179
  @abstractmethod
190
180
  async def request(
191
- self, messages: list[ModelMessage], model_settings: ModelSettings | None
181
+ self,
182
+ messages: list[ModelMessage],
183
+ model_settings: ModelSettings | None,
184
+ model_request_parameters: ModelRequestParameters,
192
185
  ) -> tuple[ModelResponse, Usage]:
193
186
  """Make a request to the model."""
194
187
  raise NotImplementedError()
195
188
 
196
189
  @asynccontextmanager
197
190
  async def request_stream(
198
- self, messages: list[ModelMessage], model_settings: ModelSettings | None
191
+ self,
192
+ messages: list[ModelMessage],
193
+ model_settings: ModelSettings | None,
194
+ model_request_parameters: ModelRequestParameters,
199
195
  ) -> AsyncIterator[StreamedResponse]:
200
196
  """Make a request to the model and return a streaming response."""
201
197
  # This method is not required, but you need to implement it if you want to support streamed responses
@@ -204,6 +200,16 @@ class AgentModel(ABC):
204
200
  # noinspection PyUnreachableCode
205
201
  yield # pragma: no cover
206
202
 
203
+ @property
204
+ def model_name(self) -> str:
205
+ """The model name."""
206
+ return self._model_name
207
+
208
+ @property
209
+ def system(self) -> str | None:
210
+ """The system / model provider, ex: openai."""
211
+ return self._system
212
+
207
213
 
208
214
  @dataclass
209
215
  class StreamedResponse(ABC):
@@ -266,7 +272,7 @@ def check_allow_model_requests() -> None:
266
272
  """Check if model requests are allowed.
267
273
 
268
274
  If you're defining your own models that have costs or latency associated with their use, you should call this in
269
- [`Model.agent_model`][pydantic_ai.models.Model.agent_model].
275
+ [`Model.request`][pydantic_ai.models.Model.request] and [`Model.request_stream`][pydantic_ai.models.Model.request_stream].
270
276
 
271
277
  Raises:
272
278
  RuntimeError: If model requests are not allowed.
@@ -307,33 +313,33 @@ def infer_model(model: Model | KnownModelName) -> Model:
307
313
  from .openai import OpenAIModel
308
314
 
309
315
  return OpenAIModel(model[7:])
310
- elif model.startswith(('gpt', 'o1')):
316
+ elif model.startswith(('gpt', 'o1', 'o3')):
311
317
  from .openai import OpenAIModel
312
318
 
313
319
  return OpenAIModel(model)
314
320
  elif model.startswith('google-gla'):
315
321
  from .gemini import GeminiModel
316
322
 
317
- return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
323
+ return GeminiModel(model[11:])
318
324
  # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
319
325
  elif model.startswith('gemini'):
320
326
  from .gemini import GeminiModel
321
327
 
322
328
  # noinspection PyTypeChecker
323
- return GeminiModel(model) # pyright: ignore[reportArgumentType]
329
+ return GeminiModel(model)
324
330
  elif model.startswith('groq:'):
325
331
  from .groq import GroqModel
326
332
 
327
- return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
333
+ return GroqModel(model[5:])
328
334
  elif model.startswith('google-vertex'):
329
335
  from .vertexai import VertexAIModel
330
336
 
331
- return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
337
+ return VertexAIModel(model[14:])
332
338
  # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
333
339
  elif model.startswith('vertexai:'):
334
340
  from .vertexai import VertexAIModel
335
341
 
336
- return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
342
+ return VertexAIModel(model[9:])
337
343
  elif model.startswith('mistral:'):
338
344
  from .mistral import MistralModel
339
345
 
@@ -28,8 +28,8 @@ from ..messages import (
28
28
  from ..settings import ModelSettings
29
29
  from ..tools import ToolDefinition
30
30
  from . import (
31
- AgentModel,
32
31
  Model,
32
+ ModelRequestParameters,
33
33
  StreamedResponse,
34
34
  cached_async_http_client,
35
35
  check_allow_model_requests,
@@ -68,14 +68,14 @@ LatestAnthropicModelNames = Literal[
68
68
  'claude-3-5-sonnet-latest',
69
69
  'claude-3-opus-latest',
70
70
  ]
71
- """Latest named Anthropic models."""
71
+ """Latest Anthropic models."""
72
72
 
73
73
  AnthropicModelName = Union[str, LatestAnthropicModelNames]
74
74
  """Possible Anthropic model names.
75
75
 
76
76
  Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
77
77
  allow any name in the type hints.
78
- Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
78
+ See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
79
79
  """
80
80
 
81
81
 
@@ -101,9 +101,11 @@ class AnthropicModel(Model):
101
101
  We anticipate adding support for streaming responses in a near-term future release.
102
102
  """
103
103
 
104
- model_name: AnthropicModelName
105
104
  client: AsyncAnthropic = field(repr=False)
106
105
 
106
+ _model_name: AnthropicModelName = field(repr=False)
107
+ _system: str | None = field(default='anthropic', repr=False)
108
+
107
109
  def __init__(
108
110
  self,
109
111
  model_name: AnthropicModelName,
@@ -124,7 +126,7 @@ class AnthropicModel(Model):
124
126
  client to use, if provided, `api_key` and `http_client` must be `None`.
125
127
  http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
126
128
  """
127
- self.model_name = model_name
129
+ self._model_name = model_name
128
130
  if anthropic_client is not None:
129
131
  assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
130
132
  assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
@@ -134,81 +136,67 @@ class AnthropicModel(Model):
134
136
  else:
135
137
  self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
136
138
 
137
- async def agent_model(
139
+ async def request(
138
140
  self,
139
- *,
140
- function_tools: list[ToolDefinition],
141
- allow_text_result: bool,
142
- result_tools: list[ToolDefinition],
143
- ) -> AgentModel:
141
+ messages: list[ModelMessage],
142
+ model_settings: ModelSettings | None,
143
+ model_request_parameters: ModelRequestParameters,
144
+ ) -> tuple[ModelResponse, usage.Usage]:
144
145
  check_allow_model_requests()
145
- tools = [self._map_tool_definition(r) for r in function_tools]
146
- if result_tools:
147
- tools += [self._map_tool_definition(r) for r in result_tools]
148
- return AnthropicAgentModel(
149
- self.client,
150
- self.model_name,
151
- allow_text_result,
152
- tools,
146
+ response = await self._messages_create(
147
+ messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
153
148
  )
154
-
155
- def name(self) -> str:
156
- return f'anthropic:{self.model_name}'
157
-
158
- @staticmethod
159
- def _map_tool_definition(f: ToolDefinition) -> ToolParam:
160
- return {
161
- 'name': f.name,
162
- 'description': f.description,
163
- 'input_schema': f.parameters_json_schema,
164
- }
165
-
166
-
167
- @dataclass
168
- class AnthropicAgentModel(AgentModel):
169
- """Implementation of `AgentModel` for Anthropic models."""
170
-
171
- client: AsyncAnthropic
172
- model_name: AnthropicModelName
173
- allow_text_result: bool
174
- tools: list[ToolParam]
175
-
176
- async def request(
177
- self, messages: list[ModelMessage], model_settings: ModelSettings | None
178
- ) -> tuple[ModelResponse, usage.Usage]:
179
- response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
180
149
  return self._process_response(response), _map_usage(response)
181
150
 
182
151
  @asynccontextmanager
183
152
  async def request_stream(
184
- self, messages: list[ModelMessage], model_settings: ModelSettings | None
153
+ self,
154
+ messages: list[ModelMessage],
155
+ model_settings: ModelSettings | None,
156
+ model_request_parameters: ModelRequestParameters,
185
157
  ) -> AsyncIterator[StreamedResponse]:
186
- response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
158
+ check_allow_model_requests()
159
+ response = await self._messages_create(
160
+ messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
161
+ )
187
162
  async with response:
188
163
  yield await self._process_streamed_response(response)
189
164
 
190
165
  @overload
191
166
  async def _messages_create(
192
- self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
167
+ self,
168
+ messages: list[ModelMessage],
169
+ stream: Literal[True],
170
+ model_settings: AnthropicModelSettings,
171
+ model_request_parameters: ModelRequestParameters,
193
172
  ) -> AsyncStream[RawMessageStreamEvent]:
194
173
  pass
195
174
 
196
175
  @overload
197
176
  async def _messages_create(
198
- self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
177
+ self,
178
+ messages: list[ModelMessage],
179
+ stream: Literal[False],
180
+ model_settings: AnthropicModelSettings,
181
+ model_request_parameters: ModelRequestParameters,
199
182
  ) -> AnthropicMessage:
200
183
  pass
201
184
 
202
185
  async def _messages_create(
203
- self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
186
+ self,
187
+ messages: list[ModelMessage],
188
+ stream: bool,
189
+ model_settings: AnthropicModelSettings,
190
+ model_request_parameters: ModelRequestParameters,
204
191
  ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
205
192
  # standalone function to make it easier to override
193
+ tools = self._get_tools(model_request_parameters)
206
194
  tool_choice: ToolChoiceParam | None
207
195
 
208
- if not self.tools:
196
+ if not tools:
209
197
  tool_choice = None
210
198
  else:
211
- if not self.allow_text_result:
199
+ if not model_request_parameters.allow_text_result:
212
200
  tool_choice = {'type': 'any'}
213
201
  else:
214
202
  tool_choice = {'type': 'auto'}
@@ -222,8 +210,8 @@ class AnthropicAgentModel(AgentModel):
222
210
  max_tokens=model_settings.get('max_tokens', 1024),
223
211
  system=system_prompt or NOT_GIVEN,
224
212
  messages=anthropic_messages,
225
- model=self.model_name,
226
- tools=self.tools or NOT_GIVEN,
213
+ model=self._model_name,
214
+ tools=tools or NOT_GIVEN,
227
215
  tool_choice=tool_choice or NOT_GIVEN,
228
216
  stream=stream,
229
217
  temperature=model_settings.get('temperature', NOT_GIVEN),
@@ -248,7 +236,7 @@ class AnthropicAgentModel(AgentModel):
248
236
  )
249
237
  )
250
238
 
251
- return ModelResponse(items, model_name=self.model_name)
239
+ return ModelResponse(items, model_name=self._model_name)
252
240
 
253
241
  async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
254
242
  peekable_response = _utils.PeekableAsyncStream(response)
@@ -258,10 +246,17 @@ class AnthropicAgentModel(AgentModel):
258
246
 
259
247
  # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
260
248
  timestamp = datetime.now(tz=timezone.utc)
261
- return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
249
+ return AnthropicStreamedResponse(
250
+ _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
251
+ )
262
252
 
263
- @staticmethod
264
- def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
253
+ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
254
+ tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
255
+ if model_request_parameters.result_tools:
256
+ tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
257
+ return tools
258
+
259
+ def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
265
260
  """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
266
261
  system_prompt: str = ''
267
262
  anthropic_messages: list[MessageParam] = []
@@ -310,20 +305,28 @@ class AnthropicAgentModel(AgentModel):
310
305
  content.append(TextBlockParam(text=item.content, type='text'))
311
306
  else:
312
307
  assert isinstance(item, ToolCallPart)
313
- content.append(_map_tool_call(item))
308
+ content.append(self._map_tool_call(item))
314
309
  anthropic_messages.append(MessageParam(role='assistant', content=content))
315
310
  else:
316
311
  assert_never(m)
317
312
  return system_prompt, anthropic_messages
318
313
 
314
+ @staticmethod
315
+ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
316
+ return ToolUseBlockParam(
317
+ id=_guard_tool_call_id(t=t, model_source='Anthropic'),
318
+ type='tool_use',
319
+ name=t.tool_name,
320
+ input=t.args_as_dict(),
321
+ )
319
322
 
320
- def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
321
- return ToolUseBlockParam(
322
- id=_guard_tool_call_id(t=t, model_source='Anthropic'),
323
- type='tool_use',
324
- name=t.tool_name,
325
- input=t.args_as_dict(),
326
- )
323
+ @staticmethod
324
+ def _map_tool_definition(f: ToolDefinition) -> ToolParam:
325
+ return {
326
+ 'name': f.name,
327
+ 'description': f.description,
328
+ 'input_schema': f.parameters_json_schema,
329
+ }
327
330
 
328
331
 
329
332
  def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
@@ -26,8 +26,8 @@ from ..messages import (
26
26
  from ..settings import ModelSettings
27
27
  from ..tools import ToolDefinition
28
28
  from . import (
29
- AgentModel,
30
29
  Model,
30
+ ModelRequestParameters,
31
31
  check_allow_model_requests,
32
32
  )
33
33
 
@@ -52,7 +52,7 @@ except ImportError as _import_error:
52
52
  "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
53
53
  ) from _import_error
54
54
 
55
- NamedCohereModels = Literal[
55
+ LatestCohereModelNames = Literal[
56
56
  'c4ai-aya-expanse-32b',
57
57
  'c4ai-aya-expanse-8b',
58
58
  'command',
@@ -67,9 +67,15 @@ NamedCohereModels = Literal[
67
67
  'command-r-plus-08-2024',
68
68
  'command-r7b-12-2024',
69
69
  ]
70
- """Latest / most popular named Cohere models."""
70
+ """Latest Cohere models."""
71
71
 
72
- CohereModelName = Union[NamedCohereModels, str]
72
+ CohereModelName = Union[str, LatestCohereModelNames]
73
+ """Possible Cohere model names.
74
+
75
+ Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
76
+ allow any name in the type hints.
77
+ See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
78
+ """
73
79
 
74
80
 
75
81
  class CohereModelSettings(ModelSettings):
@@ -88,9 +94,11 @@ class CohereModel(Model):
88
94
  Apart from `__init__`, all methods are private or match those of the base class.
89
95
  """
90
96
 
91
- model_name: CohereModelName
92
97
  client: AsyncClientV2 = field(repr=False)
93
98
 
99
+ _model_name: CohereModelName = field(repr=False)
100
+ _system: str | None = field(default='cohere', repr=False)
101
+
94
102
  def __init__(
95
103
  self,
96
104
  model_name: CohereModelName,
@@ -110,7 +118,7 @@ class CohereModel(Model):
110
118
  `api_key` and `http_client` must be `None`.
111
119
  http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
112
120
  """
113
- self.model_name: CohereModelName = model_name
121
+ self._model_name: CohereModelName = model_name
114
122
  if cohere_client is not None:
115
123
  assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
116
124
  assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
@@ -118,64 +126,28 @@ class CohereModel(Model):
118
126
  else:
119
127
  self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore
120
128
 
121
- async def agent_model(
122
- self,
123
- *,
124
- function_tools: list[ToolDefinition],
125
- allow_text_result: bool,
126
- result_tools: list[ToolDefinition],
127
- ) -> AgentModel:
128
- check_allow_model_requests()
129
- tools = [self._map_tool_definition(r) for r in function_tools]
130
- if result_tools:
131
- tools += [self._map_tool_definition(r) for r in result_tools]
132
- return CohereAgentModel(
133
- self.client,
134
- self.model_name,
135
- allow_text_result,
136
- tools,
137
- )
138
-
139
- def name(self) -> str:
140
- return f'cohere:{self.model_name}'
141
-
142
- @staticmethod
143
- def _map_tool_definition(f: ToolDefinition) -> ToolV2:
144
- return ToolV2(
145
- type='function',
146
- function=ToolV2Function(
147
- name=f.name,
148
- description=f.description,
149
- parameters=f.parameters_json_schema,
150
- ),
151
- )
152
-
153
-
154
- @dataclass
155
- class CohereAgentModel(AgentModel):
156
- """Implementation of `AgentModel` for Cohere models."""
157
-
158
- client: AsyncClientV2
159
- model_name: CohereModelName
160
- allow_text_result: bool
161
- tools: list[ToolV2]
162
-
163
129
  async def request(
164
- self, messages: list[ModelMessage], model_settings: ModelSettings | None
130
+ self,
131
+ messages: list[ModelMessage],
132
+ model_settings: ModelSettings | None,
133
+ model_request_parameters: ModelRequestParameters,
165
134
  ) -> tuple[ModelResponse, result.Usage]:
166
- response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
135
+ check_allow_model_requests()
136
+ response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
167
137
  return self._process_response(response), _map_usage(response)
168
138
 
169
139
  async def _chat(
170
140
  self,
171
141
  messages: list[ModelMessage],
172
142
  model_settings: CohereModelSettings,
143
+ model_request_parameters: ModelRequestParameters,
173
144
  ) -> ChatResponse:
145
+ tools = self._get_tools(model_request_parameters)
174
146
  cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
175
147
  return await self.client.chat(
176
- model=self.model_name,
148
+ model=self._model_name,
177
149
  messages=cohere_messages,
178
- tools=self.tools or OMIT,
150
+ tools=tools or OMIT,
179
151
  max_tokens=model_settings.get('max_tokens', OMIT),
180
152
  temperature=model_settings.get('temperature', OMIT),
181
153
  p=model_settings.get('top_p', OMIT),
@@ -201,13 +173,12 @@ class CohereAgentModel(AgentModel):
201
173
  tool_call_id=c.id,
202
174
  )
203
175
  )
204
- return ModelResponse(parts=parts, model_name=self.model_name)
176
+ return ModelResponse(parts=parts, model_name=self._model_name)
205
177
 
206
- @classmethod
207
- def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
178
+ def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
208
179
  """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
209
180
  if isinstance(message, ModelRequest):
210
- yield from cls._map_user_message(message)
181
+ yield from self._map_user_message(message)
211
182
  elif isinstance(message, ModelResponse):
212
183
  texts: list[str] = []
213
184
  tool_calls: list[ToolCallV2] = []
@@ -215,7 +186,7 @@ class CohereAgentModel(AgentModel):
215
186
  if isinstance(item, TextPart):
216
187
  texts.append(item.content)
217
188
  elif isinstance(item, ToolCallPart):
218
- tool_calls.append(_map_tool_call(item))
189
+ tool_calls.append(self._map_tool_call(item))
219
190
  else:
220
191
  assert_never(item)
221
192
  message_param = AssistantChatMessageV2(role='assistant')
@@ -227,6 +198,34 @@ class CohereAgentModel(AgentModel):
227
198
  else:
228
199
  assert_never(message)
229
200
 
201
+ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
202
+ tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
203
+ if model_request_parameters.result_tools:
204
+ tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
205
+ return tools
206
+
207
+ @staticmethod
208
+ def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
209
+ return ToolCallV2(
210
+ id=_guard_tool_call_id(t=t, model_source='Cohere'),
211
+ type='function',
212
+ function=ToolCallV2Function(
213
+ name=t.tool_name,
214
+ arguments=t.args_as_json_str(),
215
+ ),
216
+ )
217
+
218
+ @staticmethod
219
+ def _map_tool_definition(f: ToolDefinition) -> ToolV2:
220
+ return ToolV2(
221
+ type='function',
222
+ function=ToolV2Function(
223
+ name=f.name,
224
+ description=f.description,
225
+ parameters=f.parameters_json_schema,
226
+ ),
227
+ )
228
+
230
229
  @classmethod
231
230
  def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
232
231
  for part in message.parts:
@@ -253,17 +252,6 @@ class CohereAgentModel(AgentModel):
253
252
  assert_never(part)
254
253
 
255
254
 
256
- def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
257
- return ToolCallV2(
258
- id=_guard_tool_call_id(t=t, model_source='Cohere'),
259
- type='function',
260
- function=ToolCallV2Function(
261
- name=t.tool_name,
262
- arguments=t.args_as_json_str(),
263
- ),
264
- )
265
-
266
-
267
255
  def _map_usage(response: ChatResponse) -> result.Usage:
268
256
  usage = response.usage
269
257
  if usage is None: