pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_pydantic.py +1 -0
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/agent.py +107 -87
- pydantic_ai/messages.py +3 -10
- pydantic_ai/models/__init__.py +29 -1
- pydantic_ai/models/anthropic.py +94 -30
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +12 -8
- pydantic_ai/models/gemini.py +9 -9
- pydantic_ai/models/groq.py +9 -7
- pydantic_ai/models/mistral.py +12 -6
- pydantic_ai/models/ollama.py +3 -0
- pydantic_ai/models/openai.py +27 -13
- pydantic_ai/models/test.py +16 -8
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +45 -26
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +18 -18
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +6 -4
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.19.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
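Highlights of this release: a new Cohere provider (`pydantic_ai/models/cohere.py`, with matching `cohere:` entries in `KnownModelName`), streaming support for Anthropic models, and a `model_name` stamped on every `ModelResponse` so message history records which model produced each response.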
pydantic_ai/models/__init__.py
CHANGED
@@ -61,6 +61,7 @@ KnownModelName = Literal[
     'mistral:codestral-latest',
     'mistral:mistral-moderation-latest',
     'ollama:codellama',
+    'ollama:deepseek-r1',
     'ollama:gemma',
     'ollama:gemma2',
     'ollama:llama3',
@@ -81,6 +82,22 @@ KnownModelName = Literal[
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+    'cohere:c4ai-aya-expanse-32b',
+    'cohere:c4ai-aya-expanse-8b',
+    'cohere:command',
+    'cohere:command-light',
+    'cohere:command-light-nightly',
+    'cohere:command-nightly',
+    'cohere:command-r',
+    'cohere:command-r-03-2024',
+    'cohere:command-r-08-2024',
+    'cohere:command-r-plus',
+    'cohere:command-r-plus-04-2024',
+    'cohere:command-r-plus-08-2024',
+    'cohere:command-r7b-12-2024',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -145,6 +162,7 @@ class AgentModel(ABC):
 class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""
 
+    _model_name: str
     _usage: Usage = field(default_factory=Usage, init=False)
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
@@ -168,7 +186,13 @@ class StreamedResponse(ABC):
 
     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
-        return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
+        return ModelResponse(
+            parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
+        )
+
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
 
     def usage(self) -> Usage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
@@ -228,6 +252,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .test import TestModel
 
         return TestModel()
+    elif model.startswith('cohere:'):
+        from .cohere import CohereModel
+
+        return CohereModel(model[7:])
    elif model.startswith('openai:'):
         from .openai import OpenAIModel
 
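The new `cohere:` branch reuses the existing string-dispatch convention: `infer_model` strips the seven-character prefix and passes the remainder to `CohereModel`. A minimal sketch of that path (assuming the `cohere` optional dependency is installed and `COHERE_API_KEY` is set, since `CohereModel` constructs a real client immediately):

from pydantic_ai.models import infer_model

# 'cohere:command-r' takes the new branch above; model[7:] drops the 'cohere:' prefix.
model = infer_model('cohere:command-r')
print(model.name())  # 'cohere:command-r', via CohereModel.name() in the new module below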
pydantic_ai/models/anthropic.py
CHANGED
@@ -1,14 +1,16 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import usage
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsDict,
@@ -16,6 +18,7 @@ from ..messages import (
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -38,11 +41,16 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        RawContentBlockDeltaEvent,
+        RawContentBlockStartEvent,
+        RawContentBlockStopEvent,
         RawMessageDeltaEvent,
         RawMessageStartEvent,
+        RawMessageStopEvent,
         RawMessageStreamEvent,
         TextBlock,
         TextBlockParam,
+        TextDelta,
         ToolChoiceParam,
         ToolParam,
         ToolResultBlockParam,
@@ -152,7 +160,7 @@ class AnthropicAgentModel(AgentModel):
     """Implementation of `AgentModel` for Anthropic models."""
 
     client: AsyncAnthropic
-    model_name: str
+    model_name: AnthropicModelName
     allow_text_result: bool
     tools: list[ToolParam]
 
@@ -186,16 +194,22 @@ class AnthropicAgentModel(AgentModel):
         self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        model_settings = model_settings or {}
+
+        tool_choice: ToolChoiceParam | None
+
         if not self.tools:
-            tool_choice: ToolChoiceParam | None = None
-        elif not self.allow_text_result:
-            tool_choice = {'type': 'any'}
+            tool_choice = None
         else:
-            tool_choice = {'type': 'auto'}
+            if not self.allow_text_result:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
 
-        system_prompt, anthropic_messages = self._map_message(messages)
+            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
+                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-        model_settings = model_settings or {}
+        system_prompt, anthropic_messages = self._map_message(messages)
 
         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
@@ -210,8 +224,7 @@ class AnthropicAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
-    @staticmethod
-    def _process_response(response: AnthropicMessage) -> ModelResponse:
+    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
         for item in response.content:
@@ -227,26 +240,17 @@ class AnthropicAgentModel(AgentModel):
             )
         )
 
-        return ModelResponse(items)
+        return ModelResponse(items, model_name=self.model_name)
 
-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
-
-
-
-
-
-
-
-        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
-        # RawMessageStartEvent
-        # RawMessageDeltaEvent
-        # RawMessageStopEvent
-        # RawContentBlockStartEvent
-        # RawContentBlockDeltaEvent
-        # RawContentBlockDeltaEvent
-        #
-        # We might refactor streaming internally before we implement this...
+    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
+        timestamp = datetime.now(tz=timezone.utc)
+        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
 
     @staticmethod
     def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -342,3 +346,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
         response_tokens=response_usage.output_tokens,
         total_tokens=(request_tokens or 0) + response_usage.output_tokens,
     )
+
+
+@dataclass
+class AnthropicStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Anthropic models."""
+
+    _response: AsyncIterable[RawMessageStreamEvent]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        current_block: TextBlock | ToolUseBlock | None = None
+        current_json: str = ''
+
+        async for event in self._response:
+            self._usage += _map_usage(event)
+
+            if isinstance(event, RawContentBlockStartEvent):
+                current_block = event.content_block
+                if isinstance(current_block, TextBlock) and current_block.text:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
+                elif isinstance(current_block, ToolUseBlock):
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name=current_block.name,
+                        args=cast(dict[str, Any], current_block.input),
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, RawContentBlockDeltaEvent):
+                if isinstance(event.delta, TextDelta):
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
+                elif (
+                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
+                ):
+                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
+                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
+                    try:
+                        parsed_args = json_loads(current_json + event.delta.partial_json)
+                        current_json = ''
+                    except JSONDecodeError:
+                        current_json += event.delta.partial_json
+                        continue
+
+                    # For tool calls, we need to handle partial JSON updates
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name='',
+                        args=parsed_args,
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
+                current_block = None
+
+    def timestamp(self) -> datetime:
+        return self._timestamp
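Streaming for Anthropic was previously a stack of placeholder comments (the removed lines above); `AnthropicStreamedResponse` now feeds text and tool-call deltas through the shared parts manager. A sketch of exercising the new path through the public `Agent` API, assuming the 0.0.20 `run_stream` interface and `ANTHROPIC_API_KEY` in the environment:

import asyncio

from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-haiku-latest')

async def main() -> None:
    # run_stream is an async context manager; stream_text() surfaces the text
    # deltas emitted by AnthropicStreamedResponse._get_event_iterator.
    async with agent.run_stream('Say hello.') as response:
        async for chunk in response.stream_text():
            print(chunk)

asyncio.run(main())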
pydantic_ai/models/cohere.py
ADDED
@@ -0,0 +1,278 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Literal, TypeAlias, Union
+
+from cohere import TextAssistantMessageContentItem
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    check_allow_model_requests,
+)
+
+try:
+    from cohere import (
+        AssistantChatMessageV2,
+        AsyncClientV2,
+        ChatMessageV2,
+        ChatResponse,
+        SystemChatMessageV2,
+        ToolCallV2,
+        ToolCallV2Function,
+        ToolChatMessageV2,
+        ToolV2,
+        ToolV2Function,
+        UserChatMessageV2,
+    )
+    from cohere.v2.client import OMIT
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `cohere` to use the Cohere model, '
+        "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
+    ) from _import_error
+
+CohereModelName: TypeAlias = Union[
+    str,
+    Literal[
+        'c4ai-aya-expanse-32b',
+        'c4ai-aya-expanse-8b',
+        'command',
+        'command-light',
+        'command-light-nightly',
+        'command-nightly',
+        'command-r',
+        'command-r-03-2024',
+        'command-r-08-2024',
+        'command-r-plus',
+        'command-r-plus-04-2024',
+        'command-r-plus-08-2024',
+        'command-r7b-12-2024',
+    ],
+]
+
+
+@dataclass(init=False)
+class CohereModel(Model):
+    """A model that uses the Cohere API.
+
+    Internally, this uses the [Cohere Python client](
+    https://github.com/cohere-ai/cohere-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: CohereModelName
+    client: AsyncClientV2 = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: CohereModelName,
+        *,
+        api_key: str | None = None,
+        cohere_client: AsyncClientV2 | None = None,
+    ):
+        """Initialize an Cohere model.
+
+        Args:
+            model_name: The name of the Cohere model to use. List of model names
+                available [here](https://docs.cohere.com/docs/models#command).
+            api_key: The API key to use for authentication, if not provided, the
+                `COHERE_API_KEY` environment variable will be used if available.
+            cohere_client: An existing Cohere async client to use. If provided,
+                `api_key` must be `None`.
+        """
+        self.model_name: CohereModelName = model_name
+        if cohere_client is not None:
+            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+            self.client = cohere_client
+        else:
+            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return CohereAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return f'cohere:{self.model_name}'
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
+
+@dataclass
+class CohereAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Cohere models."""
+
+    client: AsyncClientV2
+    model_name: CohereModelName
+    allow_text_result: bool
+    tools: list[ToolV2]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._chat(messages, model_settings)
+        return self._process_response(response), _map_usage(response)
+
+    async def _chat(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ) -> ChatResponse:
+        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+        model_settings = model_settings or {}
+        return await self.client.chat(
+            model=self.model_name,
+            messages=cohere_messages,
+            tools=self.tools or OMIT,
+            max_tokens=model_settings.get('max_tokens', OMIT),
+            temperature=model_settings.get('temperature', OMIT),
+            p=model_settings.get('top_p', OMIT),
+        )
+
+    def _process_response(self, response: ChatResponse) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        parts: list[ModelResponsePart] = []
+        if response.message.content is not None and len(response.message.content) > 0:
+            # While Cohere's API returns a list, it only does that for future proofing
+            # and currently only one item is being returned.
+            choice = response.message.content[0]
+            parts.append(TextPart(choice.text))
+        for c in response.message.tool_calls or []:
+            if c.function and c.function.name and c.function.arguments:
+                parts.append(
+                    ToolCallPart.from_raw_args(
+                        tool_name=c.function.name,
+                        args=c.function.arguments,
+                        tool_call_id=c.id,
+                    )
+                )
+        return ModelResponse(parts=parts, model_name=self.model_name)
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[ToolCallV2] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = AssistantChatMessageV2(role='assistant')
+            if texts:
+                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+            if tool_calls:
+                message_param.tool_calls = tool_calls
+            yield message_param
+        else:
+            assert_never(message)
+
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield SystemChatMessageV2(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield UserChatMessageV2(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield ToolChatMessageV2(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield UserChatMessageV2(role='user', content=part.model_response())
+                else:
+                    yield ToolChatMessageV2(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+    return ToolCallV2(
+        id=_guard_tool_call_id(t=t, model_source='Cohere'),
+        type='function',
+        function=ToolCallV2Function(
+            name=t.tool_name,
+            arguments=t.args_as_json_str(),
+        ),
+    )
+
+
+def _map_usage(response: ChatResponse) -> result.Usage:
+    usage = response.usage
+    if usage is None:
+        return result.Usage()
+    else:
+        details: dict[str, int] = {}
+        if usage.billed_units is not None:
+            if usage.billed_units.input_tokens:
+                details['input_tokens'] = int(usage.billed_units.input_tokens)
+            if usage.billed_units.output_tokens:
+                details['output_tokens'] = int(usage.billed_units.output_tokens)
+            if usage.billed_units.search_units:
+                details['search_units'] = int(usage.billed_units.search_units)
+            if usage.billed_units.classifications:
+                details['classifications'] = int(usage.billed_units.classifications)
+
+        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
+        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
+        return result.Usage(
+            request_tokens=request_tokens,
+            response_tokens=response_tokens,
+            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+            details=details,
+        )
pydantic_ai/models/function.py
CHANGED
@@ -71,16 +71,15 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
+            self.function,
+            self.stream_function,
+            AgentInfo(function_tools, allow_text_result, result_tools, None),
         )
 
     def name(self) -> str:
-        labels: list[str] = []
-        if self.function is not None:
-            labels.append(self.function.__name__)
-        if self.stream_function is not None:
-            labels.append(f'stream-{self.stream_function.__name__}')
-        return f'function:{",".join(labels)}'
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        return f'function:{function_name}:{stream_function_name}'
 
 
 @dataclass(frozen=True)
@@ -147,12 +146,15 @@ class FunctionAgentModel(AgentModel):
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+        model_name = f'function:{self.function.__name__}'
+
         if inspect.iscoroutinefunction(self.function):
             response = await self.function(messages, agent_info)
         else:
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
+        response.model_name = model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
@@ -163,13 +165,15 @@ class FunctionAgentModel(AgentModel):
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        model_name = f'function:{self.stream_function.__name__}'
+
         response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
 
         first = await response_stream.peek()
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(response_stream)
+        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
 
 
 @dataclass
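These changes stamp a synthetic model name, `function:` plus the wrapped function's `__name__`, onto every response a `FunctionModel` produces, so message history records which stub generated it. A sketch under the 0.0.20 API:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    return ModelResponse(parts=[TextPart('hello')])

agent = Agent(FunctionModel(echo))
result = agent.run_sync('hi')
# request() assigns response.model_name before returning, per the hunk above
print(result.all_messages()[-1].model_name)  # 'function:echo'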
pydantic_ai/models/gemini.py
CHANGED
@@ -99,6 +99,7 @@ class GeminiModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -151,7 +152,6 @@ class GeminiAgentModel(AgentModel):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ):
-        check_allow_model_requests()
         tools = [_function_from_abstract_tool(t) for t in function_tools]
         if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
@@ -229,15 +229,13 @@
             raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
         yield r
 
-    @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelResponse:
+    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts)
+        return _process_response_from_parts(parts, model_name=self.model_name)
 
-    @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -258,7 +256,7 @@
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -432,7 +430,9 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
-def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+def _process_response_from_parts(
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
@@ -448,7 +448,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
         raise exceptions.UnexpectedModelBehavior(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
-    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
+    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):