pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -26,8 +26,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     check_allow_model_requests,
 )
 
@@ -52,7 +52,7 @@ except ImportError as _import_error:
         "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
     ) from _import_error
 
-NamedCohereModels = Literal[
+LatestCohereModelNames = Literal[
     'c4ai-aya-expanse-32b',
     'c4ai-aya-expanse-8b',
     'command',
@@ -67,9 +67,15 @@ NamedCohereModels = Literal[
     'command-r-plus-08-2024',
     'command-r7b-12-2024',
 ]
-"""Latest / most popular named Cohere models."""
+"""Latest Cohere models."""
 
-CohereModelName = Union[NamedCohereModels, str]
+CohereModelName = Union[str, LatestCohereModelNames]
+"""Possible Cohere model names.
+
+Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
+"""
 
 
 class CohereModelSettings(ModelSettings):
@@ -88,9 +94,11 @@ class CohereModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: CohereModelName
     client: AsyncClientV2 = field(repr=False)
 
+    _model_name: CohereModelName = field(repr=False)
+    _system: str | None = field(default='cohere', repr=False)
+
     def __init__(
         self,
         model_name: CohereModelName,
@@ -110,7 +118,7 @@ class CohereModel(Model):
                 `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name: CohereModelName = model_name
+        self._model_name: CohereModelName = model_name
         if cohere_client is not None:
             assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
@@ -118,64 +126,38 @@ class CohereModel(Model):
         else:
             self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
 
-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, result.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return CohereAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
-
-    def name(self) -> str:
-        return f'cohere:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
-        return ToolV2(
-            type='function',
-            function=ToolV2Function(
-                name=f.name,
-                description=f.description,
-                parameters=f.parameters_json_schema,
-            ),
-        )
-
-
-@dataclass
-class CohereAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Cohere models."""
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
+        return self._process_response(response), _map_usage(response)
 
-    client: AsyncClientV2
-    model_name: CohereModelName
-    allow_text_result: bool
-    tools: list[ToolV2]
+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
 
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
-        return self._process_response(response), _map_usage(response)
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
 
     async def _chat(
         self,
         messages: list[ModelMessage],
         model_settings: CohereModelSettings,
+        model_request_parameters: ModelRequestParameters,
    ) -> ChatResponse:
+        tools = self._get_tools(model_request_parameters)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
         return await self.client.chat(
-            model=self.model_name,
+            model=self._model_name,
             messages=cohere_messages,
-            tools=self.tools or OMIT,
+            tools=tools or OMIT,
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
@@ -201,13 +183,12 @@ class CohereAgentModel(AgentModel):
                         tool_call_id=c.id,
                     )
                 )
-        return ModelResponse(parts=parts, model_name=self.model_name)
+        return ModelResponse(parts=parts, model_name=self._model_name)
 
-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+    def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[ToolCallV2] = []
@@ -215,7 +196,7 @@ class CohereAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = AssistantChatMessageV2(role='assistant')
@@ -227,6 +208,34 @@ class CohereAgentModel(AgentModel):
         else:
             assert_never(message)
 
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+        return ToolCallV2(
+            id=_guard_tool_call_id(t=t, model_source='Cohere'),
+            type='function',
+            function=ToolCallV2Function(
+                name=t.tool_name,
+                arguments=t.args_as_json_str(),
+            ),
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
         for part in message.parts:
@@ -253,17 +262,6 @@ class CohereAgentModel(AgentModel):
                 assert_never(part)
 
 
-def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
-    return ToolCallV2(
-        id=_guard_tool_call_id(t=t, model_source='Cohere'),
-        type='function',
-        function=ToolCallV2Function(
-            name=t.tool_name,
-            arguments=t.args_as_json_str(),
-        ),
-    )
-
-
 def _map_usage(response: ChatResponse) -> result.Usage:
     usage = response.usage
     if usage is None:
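
The hunks above (apparently from the Cohere model module) fold the old two-step flow, where `CohereModel.agent_model(...)` returned a `CohereAgentModel` whose `request()` made the actual call, into a single `CohereModel.request()` that takes a `ModelRequestParameters` object. Below is a minimal sketch, not taken from the package, of driving the new interface directly; it assumes `ModelRequestParameters` can be constructed from the three fields the diff reads (`function_tools`, `allow_text_result`, `result_tools`), and that a Cohere API key is available in the environment. The model name and prompt are placeholders; in normal use the `Agent` builds these arguments for you.

# Minimal sketch of the 0.0.24 call shape; ModelRequestParameters field names are
# assumed from the attributes read in the hunks above.
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.cohere import CohereModel


async def main() -> None:
    model = CohereModel('command-r7b-12-2024')  # any CohereModelName string is accepted
    params = ModelRequestParameters(
        function_tools=[],       # list[ToolDefinition]; an empty list means tools are OMITted
        allow_text_result=True,
        result_tools=[],
    )
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello!')])]
    # request() now does what agent_model(...).request(...) did in 0.0.22
    response, usage = await model.request(messages, None, params)
    print(response.parts, usage)


asyncio.run(main())
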
@@ -4,7 +4,7 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union
@@ -27,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, Model, StreamedResponse
+from . import Model, ModelRequestParameters, StreamedResponse
 
 
 @dataclass(init=False)
@@ -40,6 +40,9 @@ class FunctionModel(Model):
     function: FunctionDef | None = None
     stream_function: StreamFunctionDef | None = None
 
+    _model_name: str = field(repr=False)
+    _system: str | None = field(default=None, repr=False)
+
     @overload
     def __init__(self, function: FunctionDef) -> None: ...
 
@@ -63,23 +66,70 @@ class FunctionModel(Model):
         self.function = function
         self.stream_function = stream_function
 
-    async def agent_model(
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        self._model_name = f'function:{function_name}:{stream_function_name}'
+
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        return FunctionAgentModel(
-            self.function,
-            self.stream_function,
-            AgentInfo(function_tools, allow_text_result, result_tools, None),
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
         )
 
-    def name(self) -> str:
-        function_name = self.function.__name__ if self.function is not None else ''
-        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        return f'function:{function_name}:{stream_function_name}'
+        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+
+        if inspect.iscoroutinefunction(self.function):
+            response = await self.function(messages, agent_info)
+        else:
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
+        response.model_name = f'function:{self.function.__name__}'
+        # TODO is `messages` right here? Should it just be new messages?
+        return response, _estimate_usage(chain(messages, [response]))
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        agent_info = AgentInfo(
+            model_request_parameters.function_tools,
+            model_request_parameters.allow_text_result,
+            model_request_parameters.result_tools,
+            model_settings,
+        )
+
+        assert (
+            self.stream_function is not None
+        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+
+        response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
+
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
+
+        yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
 
 
 @dataclass(frozen=True)
@@ -119,9 +169,11 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
+# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
 FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
+# TODO: Change signature as indicated above
 StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.
 
@@ -132,54 +184,11 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
 """
 
 
-@dataclass
-class FunctionAgentModel(AgentModel):
-    """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    function: FunctionDef | None
-    stream_function: StreamFunctionDef | None
-    agent_info: AgentInfo
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        agent_info = replace(self.agent_info, model_settings=model_settings)
-
-        assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
-        model_name = f'function:{self.function.__name__}'
-
-        if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, agent_info)
-        else:
-            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
-            assert isinstance(response_, ModelResponse), response_
-            response = response_
-        response.model_name = model_name
-        # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_usage(chain(messages, [response]))
-
-    @asynccontextmanager
-    async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[StreamedResponse]:
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        model_name = f'function:{self.stream_function.__name__}'
-
-        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
-
-        first = await response_stream.peek()
-        if isinstance(first, _utils.Unset):
-            raise ValueError('Stream function must return at least one item')
-
-        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
-
-
 @dataclass
 class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
+    _model_name: str
     _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
@@ -207,7 +216,14 @@ class FunctionStreamedResponse(StreamedResponse):
             if maybe_event is not None:
                 yield maybe_event
 
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
        return self._timestamp
 
 
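
The `FunctionModel` hunks keep the `FunctionDef` callback signature unchanged (it still receives `(messages, AgentInfo)`; the TODO comments above flag a future signature change), and only move the `AgentInfo` construction from the deleted `FunctionAgentModel` into `FunctionModel.request` and `request_stream`. A short sketch of the usage pattern under this release's public API, with `echo_model` as a hypothetical example callback:

# Sketch of FunctionModel usage after the refactor; the callback signature is the
# same, only the plumbing that builds AgentInfo moved into FunctionModel itself.
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def echo_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # info.function_tools / info.allow_text_result / info.result_tools are now filled
    # from ModelRequestParameters inside FunctionModel.request()
    return ModelResponse(parts=[TextPart(content=f'{len(messages)} message(s), {len(info.function_tools)} tool(s)')])


agent = Agent(FunctionModel(echo_model))
result = agent.run_sync('ping')
print(result.data)  # e.g. '1 message(s), 0 tool(s)'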