pydantic-ai-slim 0.0.18-py3-none-any.whl → 0.0.20-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/models/anthropic.py
CHANGED
@@ -1,14 +1,16 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsDict,
@@ -16,6 +18,7 @@ from ..messages import (
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -27,8 +30,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -38,11 +41,16 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        RawContentBlockDeltaEvent,
+        RawContentBlockStartEvent,
+        RawContentBlockStopEvent,
         RawMessageDeltaEvent,
         RawMessageStartEvent,
+        RawMessageStopEvent,
         RawMessageStreamEvent,
         TextBlock,
         TextBlockParam,
+        TextDelta,
         ToolChoiceParam,
         ToolParam,
         ToolResultBlockParam,
@@ -152,20 +160,20 @@ class AnthropicAgentModel(AgentModel):
     """Implementation of `AgentModel` for Anthropic models."""
 
     client: AsyncAnthropic
-    model_name:
+    model_name: AnthropicModelName
     allow_text_result: bool
     tools: list[ToolParam]
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse,
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._messages_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._messages_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -186,16 +194,22 @@ class AnthropicAgentModel(AgentModel):
         self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        model_settings = model_settings or {}
+
+        tool_choice: ToolChoiceParam | None
+
         if not self.tools:
-            tool_choice
-        elif not self.allow_text_result:
-            tool_choice = {'type': 'any'}
+            tool_choice = None
         else:
-
+            if not self.allow_text_result:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
 
-
+        if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
+            tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-
+        system_prompt, anthropic_messages = self._map_message(messages)
 
         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
@@ -210,43 +224,33 @@ class AnthropicAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
-
-    def _process_response(response: AnthropicMessage) -> ModelResponse:
+    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
         for item in response.content:
             if isinstance(item, TextBlock):
-                items.append(TextPart(item.text))
+                items.append(TextPart(content=item.text))
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
                     ToolCallPart.from_raw_args(
-                        item.name,
-                        cast(dict[str, Any], item.input),
-                        item.id,
+                        tool_name=item.name,
+                        args=cast(dict[str, Any], item.input),
+                        tool_call_id=item.id,
                     )
                 )
 
-        return ModelResponse(items)
+        return ModelResponse(items, model_name=self.model_name)
 
-
-
-
-
-
-
-
-
-
-        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
-        # RawMessageStartEvent
-        # RawMessageDeltaEvent
-        # RawMessageStopEvent
-        # RawContentBlockStartEvent
-        # RawContentBlockDeltaEvent
-        # RawContentBlockDeltaEvent
-        #
-        # We might refactor streaming internally before we implement this...
+    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
+        timestamp = datetime.now(tz=timezone.utc)
+        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
 
     @staticmethod
     def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -315,30 +319,90 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
     )
 
 
-def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) ->
+def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
     if isinstance(message, AnthropicMessage):
-
+        response_usage = message.usage
     else:
         if isinstance(message, RawMessageStartEvent):
-
+            response_usage = message.message.usage
        elif isinstance(message, RawMessageDeltaEvent):
-
+            response_usage = message.usage
         else:
             # No usage information provided in:
             # - RawMessageStopEvent
             # - RawContentBlockStartEvent
             # - RawContentBlockDeltaEvent
             # - RawContentBlockStopEvent
-
+            response_usage = None
 
-    if
-        return
+    if response_usage is None:
+        return usage.Usage()
 
-    request_tokens = getattr(
+    request_tokens = getattr(response_usage, 'input_tokens', None)
 
-    return
+    return usage.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
         request_tokens=request_tokens,
-        response_tokens=
-        total_tokens=(request_tokens or 0) +
+        response_tokens=response_usage.output_tokens,
+        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
     )
+
+
+@dataclass
+class AnthropicStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Anthropic models."""
+
+    _response: AsyncIterable[RawMessageStreamEvent]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        current_block: TextBlock | ToolUseBlock | None = None
+        current_json: str = ''
+
+        async for event in self._response:
+            self._usage += _map_usage(event)
+
+            if isinstance(event, RawContentBlockStartEvent):
+                current_block = event.content_block
+                if isinstance(current_block, TextBlock) and current_block.text:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
+                elif isinstance(current_block, ToolUseBlock):
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name=current_block.name,
+                        args=cast(dict[str, Any], current_block.input),
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, RawContentBlockDeltaEvent):
+                if isinstance(event.delta, TextDelta):
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
+                elif (
+                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
+                ):
+                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
+                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
+                    try:
+                        parsed_args = json_loads(current_json + event.delta.partial_json)
+                        current_json = ''
+                    except JSONDecodeError:
+                        current_json += event.delta.partial_json
+                        continue
+
+                    # For tool calls, we need to handle partial JSON updates
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name='',
+                        args=parsed_args,
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
+                current_block = None
+
+    def timestamp(self) -> datetime:
+        return self._timestamp
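Taken together, these changes replace the old not-implemented streaming stub in anthropic.py with real support: responses now stream through the new `StreamedResponse` base class and the parts manager added in `pydantic_ai/_parts_manager.py`, instead of the removed `EitherStreamedResponse`. A minimal sketch (not part of the diff) of exercising the new path through the public `Agent` API, assuming 0.0.20's `anthropic:...` model-string naming and an `ANTHROPIC_API_KEY` in the environment; the model name and prompt are illustrative:

import asyncio

from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-sonnet-latest')


async def main() -> None:
    async with agent.run_stream('What is the capital of France?') as result:
        # Accumulated text is assembled from Anthropic TextDelta events by
        # AnthropicStreamedResponse._get_event_iterator and the parts manager.
        async for text in result.stream_text():
            print(text)


asyncio.run(main())

The `input_json_delta` branch in `_get_event_iterator` buffers partial JSON across deltas until it parses, then emits a tool-call update. A standalone illustration of that buffering technique, with made-up deltas:

from json import JSONDecodeError, loads as json_loads

deltas = ['{"city": "Par', 'is", "country"', ': "France"}']  # hypothetical token stream
current_json = ''
for delta in deltas:
    try:
        args = json_loads(current_json + delta)  # JSON is complete: emit an update
        current_json = ''
        print(args)  # prints {'city': 'Paris', 'country': 'France'} on the last delta
    except JSONDecodeError:
        current_json += delta  # incomplete so far; wait for more tokens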
pydantic_ai/models/cohere.py
ADDED
@@ -0,0 +1,278 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Literal, TypeAlias, Union
+
+from cohere import TextAssistantMessageContentItem
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    check_allow_model_requests,
+)
+
+try:
+    from cohere import (
+        AssistantChatMessageV2,
+        AsyncClientV2,
+        ChatMessageV2,
+        ChatResponse,
+        SystemChatMessageV2,
+        ToolCallV2,
+        ToolCallV2Function,
+        ToolChatMessageV2,
+        ToolV2,
+        ToolV2Function,
+        UserChatMessageV2,
+    )
+    from cohere.v2.client import OMIT
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `cohere` to use the Cohere model, '
+        "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
+    ) from _import_error
+
+CohereModelName: TypeAlias = Union[
+    str,
+    Literal[
+        'c4ai-aya-expanse-32b',
+        'c4ai-aya-expanse-8b',
+        'command',
+        'command-light',
+        'command-light-nightly',
+        'command-nightly',
+        'command-r',
+        'command-r-03-2024',
+        'command-r-08-2024',
+        'command-r-plus',
+        'command-r-plus-04-2024',
+        'command-r-plus-08-2024',
+        'command-r7b-12-2024',
+    ],
+]
+
+
+@dataclass(init=False)
+class CohereModel(Model):
+    """A model that uses the Cohere API.
+
+    Internally, this uses the [Cohere Python client](
+    https://github.com/cohere-ai/cohere-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: CohereModelName
+    client: AsyncClientV2 = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: CohereModelName,
+        *,
+        api_key: str | None = None,
+        cohere_client: AsyncClientV2 | None = None,
+    ):
+        """Initialize an Cohere model.
+
+        Args:
+            model_name: The name of the Cohere model to use. List of model names
+                available [here](https://docs.cohere.com/docs/models#command).
+            api_key: The API key to use for authentication, if not provided, the
+                `COHERE_API_KEY` environment variable will be used if available.
+            cohere_client: An existing Cohere async client to use. If provided,
+                `api_key` must be `None`.
+        """
+        self.model_name: CohereModelName = model_name
+        if cohere_client is not None:
+            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+            self.client = cohere_client
+        else:
+            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return CohereAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return f'cohere:{self.model_name}'
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
+
+@dataclass
+class CohereAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Cohere models."""
+
+    client: AsyncClientV2
+    model_name: CohereModelName
+    allow_text_result: bool
+    tools: list[ToolV2]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._chat(messages, model_settings)
+        return self._process_response(response), _map_usage(response)
+
+    async def _chat(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ) -> ChatResponse:
+        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+        model_settings = model_settings or {}
+        return await self.client.chat(
+            model=self.model_name,
+            messages=cohere_messages,
+            tools=self.tools or OMIT,
+            max_tokens=model_settings.get('max_tokens', OMIT),
+            temperature=model_settings.get('temperature', OMIT),
+            p=model_settings.get('top_p', OMIT),
+        )
+
+    def _process_response(self, response: ChatResponse) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        parts: list[ModelResponsePart] = []
+        if response.message.content is not None and len(response.message.content) > 0:
+            # While Cohere's API returns a list, it only does that for future proofing
+            # and currently only one item is being returned.
+            choice = response.message.content[0]
+            parts.append(TextPart(choice.text))
+        for c in response.message.tool_calls or []:
+            if c.function and c.function.name and c.function.arguments:
+                parts.append(
+                    ToolCallPart.from_raw_args(
+                        tool_name=c.function.name,
+                        args=c.function.arguments,
+                        tool_call_id=c.id,
+                    )
+                )
+        return ModelResponse(parts=parts, model_name=self.model_name)
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[ToolCallV2] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = AssistantChatMessageV2(role='assistant')
+            if texts:
+                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+            if tool_calls:
+                message_param.tool_calls = tool_calls
+            yield message_param
+        else:
+            assert_never(message)
+
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield SystemChatMessageV2(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield UserChatMessageV2(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield ToolChatMessageV2(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield UserChatMessageV2(role='user', content=part.model_response())
+                else:
+                    yield ToolChatMessageV2(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+    return ToolCallV2(
+        id=_guard_tool_call_id(t=t, model_source='Cohere'),
+        type='function',
+        function=ToolCallV2Function(
+            name=t.tool_name,
+            arguments=t.args_as_json_str(),
+        ),
+    )
+
+
+def _map_usage(response: ChatResponse) -> result.Usage:
+    usage = response.usage
+    if usage is None:
+        return result.Usage()
+    else:
+        details: dict[str, int] = {}
+        if usage.billed_units is not None:
+            if usage.billed_units.input_tokens:
+                details['input_tokens'] = int(usage.billed_units.input_tokens)
+            if usage.billed_units.output_tokens:
+                details['output_tokens'] = int(usage.billed_units.output_tokens)
+            if usage.billed_units.search_units:
+                details['search_units'] = int(usage.billed_units.search_units)
+            if usage.billed_units.classifications:
+                details['classifications'] = int(usage.billed_units.classifications)
+
+        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
+        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
+        return result.Usage(
+            request_tokens=request_tokens,
+            response_tokens=response_tokens,
+            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+            details=details,
+        )