pydantic-ai-slim 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +67 -55
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +503 -163
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +25 -5
- pydantic_ai/models/_json_schema.py +156 -0
- pydantic_ai/models/anthropic.py +14 -4
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +65 -75
- pydantic_ai/models/groq.py +34 -29
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +67 -58
- pydantic_ai/models/openai.py +113 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/models/wrapper.py +3 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/azure.py +2 -2
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +3 -3
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.0.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.54.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/entry_points.txt +0 -0
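The recurring change across the model adapters below is the rename of "result" terminology to "output": `_result.py` becomes `_output.py`, and with it `ModelRequestParameters.allow_text_result` / `result_tools` become `allow_text_output` / `output_tools`, while `result.Usage` is now referenced via the `usage` module. A minimal sketch of the renamed fields (constructing `ModelRequestParameters` directly like this is an illustration, not something this diff shows):

from pydantic_ai.models import ModelRequestParameters

# Sketch of the 0.1.0 field names seen throughout the diffs below;
# the keyword construction here is an assumption.
params = ModelRequestParameters(
    function_tools=[],
    allow_text_output=True,  # 0.0.54: allow_text_result
    output_tools=[],         # 0.0.54: result_tools
)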
pydantic_ai/models/anthropic.py
CHANGED
@@ -31,7 +31,14 @@ from ..messages import (
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    cached_async_http_client,
+    check_allow_model_requests,
+    get_user_agent,
+)
 
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
@@ -207,7 +214,7 @@ class AnthropicModel(Model):
         if not tools:
             tool_choice = None
         else:
-            if not model_request_parameters.allow_text_result:
+            if not model_request_parameters.allow_text_output:
                 tool_choice = {'type': 'any'}
             else:
                 tool_choice = {'type': 'auto'}
@@ -231,6 +238,7 @@ class AnthropicModel(Model):
                 top_p=model_settings.get('top_p', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
+                extra_headers={'User-Agent': get_user_agent()},
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
@@ -269,8 +277,8 @@ class AnthropicModel(Model):
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -324,6 +332,8 @@ class AnthropicModel(Model):
                 anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
             else:
                 assert_never(m)
+        if instructions := self._get_instructions(messages):
+            system_prompt = f'{instructions}\n\n{system_prompt}'
         return system_prompt, anthropic_messages
 
     @staticmethod
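The new `_get_instructions` hook (presumably defined on the base `Model` in `models/__init__.py`, which grows by 25 lines in this release) lets each adapter fold run-level instructions into its native system-prompt slot; for Anthropic they are prepended to the system prompt, as the last hunk shows. A hedged sketch of how this surfaces to users, assuming `instructions` is the corresponding new `Agent` argument:

from pydantic_ai import Agent

# Assumption: `instructions` is the Agent-level argument backed by
# `_get_instructions`; the model name is just an example.
agent = Agent(
    'anthropic:claude-3-5-sonnet-latest',
    instructions='Answer in one short paragraph.',
)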
pydantic_ai/models/bedrock.py
CHANGED
@@ -2,17 +2,17 @@ from __future__ import annotations
 
 import functools
 import typing
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable, Mapping
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import TYPE_CHECKING, Generic, Literal, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
 
 import anyio
 import anyio.to_thread
 from typing_extensions import ParamSpec, assert_never
 
-from pydantic_ai import _utils, result
+from pydantic_ai import _utils, usage
 from pydantic_ai.messages import (
     AudioUrl,
     BinaryContent,
@@ -29,6 +29,7 @@ from pydantic_ai.messages import (
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
+    VideoUrl,
 )
 from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client
 from pydantic_ai.providers import Provider, infer_provider
@@ -46,12 +47,16 @@ if TYPE_CHECKING:
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
+        GuardrailConfigurationTypeDef,
         ImageBlockTypeDef,
         InferenceConfigurationTypeDef,
         MessageUnionTypeDef,
+        PerformanceConfigurationTypeDef,
+        PromptVariableValuesTypeDef,
         SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
         ToolTypeDef,
+        VideoBlockTypeDef,
     )
 
 
@@ -114,10 +119,49 @@ P = ParamSpec('P')
 T = typing.TypeVar('T')
 
 
-class BedrockModelSettings(ModelSettings):
+class BedrockModelSettings(ModelSettings, total=False):
     """Settings for Bedrock models.
 
     ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+
+    See [the Bedrock Converse API docs](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax) for a full list.
+    See [the boto3 implementation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html) of the Bedrock Converse API.
+    """
+
+    bedrock_guardrail_config: GuardrailConfigurationTypeDef
+    """Content moderation and safety settings for Bedrock API requests.
+
+    See more about it on <https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConfiguration.html>.
+    """
+
+    bedrock_performance_configuration: PerformanceConfigurationTypeDef
+    """Performance optimization settings for model inference.
+
+    See more about it on <https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_PerformanceConfiguration.html>.
+    """
+
+    bedrock_request_metadata: dict[str, str]
+    """Additional metadata to attach to Bedrock API requests.
+
+    See more about it on <https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax>.
+    """
+
+    bedrock_additional_model_response_fields_paths: list[str]
+    """JSON paths to extract additional fields from model responses.
+
+    See more about it on <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>.
+    """
+
+    bedrock_prompt_variables: Mapping[str, PromptVariableValuesTypeDef]
+    """Variables for substitution into prompt templates.
+
+    See more about it on <https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_PromptVariableValues.html>.
+    """
+
+    bedrock_additional_model_requests_fields: Mapping[str, Any]
+    """Additional model-specific parameters to include in requests.
+
+    See more about it on <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>.
     """
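Since `BedrockModelSettings` is now a `total=False` TypedDict, every new `bedrock_`-prefixed key is optional and can be merged with generic `ModelSettings`. A minimal sketch (the guardrail identifier and metadata values are hypothetical placeholders):

from pydantic_ai.models.bedrock import BedrockModelSettings

# All keys are optional thanks to total=False; the values below are
# placeholders, not real AWS resources.
settings = BedrockModelSettings(
    bedrock_guardrail_config={'guardrailIdentifier': 'my-guardrail', 'guardrailVersion': '1'},
    bedrock_performance_configuration={'latency': 'optimized'},
    bedrock_request_metadata={'team': 'platform'},
)

These keys are wired into the Converse request as `guardrailConfig`, `performanceConfig`, and `requestMetadata` in `_messages_create`, shown below.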
@@ -164,8 +208,8 @@ class BedrockConverseModel(Model):
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
     @staticmethod
@@ -187,8 +231,9 @@ class BedrockConverseModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._messages_create(messages, False, model_settings, model_request_parameters)
+    ) -> tuple[ModelResponse, usage.Usage]:
+        settings = cast(BedrockModelSettings, model_settings or {})
+        response = await self._messages_create(messages, False, settings, model_request_parameters)
         return await self._process_response(response)
 
     @asynccontextmanager
@@ -198,10 +243,11 @@ class BedrockConverseModel(Model):
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings, model_request_parameters)
+        settings = cast(BedrockModelSettings, model_settings or {})
+        response = await self._messages_create(messages, True, settings, model_request_parameters)
         yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)
 
-    async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, result.Usage]:
+    async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, usage.Usage]:
         items: list[ModelResponsePart] = []
         if message := response['output'].get('message'):
             for item in message['content']:
@@ -217,19 +263,19 @@ class BedrockConverseModel(Model):
                         tool_call_id=tool_use['toolUseId'],
                     ),
                 )
-        usage = result.Usage(
+        u = usage.Usage(
             request_tokens=response['usage']['inputTokens'],
             response_tokens=response['usage']['outputTokens'],
             total_tokens=response['usage']['totalTokens'],
         )
-        return ModelResponse(items, model_name=self.model_name), usage
+        return ModelResponse(items, model_name=self.model_name), u
 
     @overload
     async def _messages_create(
         self,
         messages: list[ModelMessage],
         stream: Literal[True],
-        model_settings: ModelSettings | None,
+        model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> EventStream[ConverseStreamOutputTypeDef]:
         pass
@@ -239,7 +285,7 @@ class BedrockConverseModel(Model):
         self,
         messages: list[ModelMessage],
         stream: Literal[False],
-        model_settings: ModelSettings | None,
+        model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ConverseResponseTypeDef:
         pass
@@ -248,14 +294,14 @@ class BedrockConverseModel(Model):
         self,
         messages: list[ModelMessage],
         stream: bool,
-        model_settings: ModelSettings | None,
+        model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
         tools = self._get_tools(model_request_parameters)
         support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
         if not tools or not support_tools_choice:
             tool_choice: ToolChoiceTypeDef = {}
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             tool_choice = {'any': {}}
         else:
             tool_choice = {'auto': {}}
@@ -269,6 +315,24 @@ class BedrockConverseModel(Model):
             'system': system_prompt,
             'inferenceConfig': inference_config,
         }
+
+        # Bedrock supports a set of specific extra parameters
+        if model_settings:
+            if guardrail_config := model_settings.get('bedrock_guardrail_config', None):
+                params['guardrailConfig'] = guardrail_config
+            if performance_configuration := model_settings.get('bedrock_performance_configuration', None):
+                params['performanceConfig'] = performance_configuration
+            if request_metadata := model_settings.get('bedrock_request_metadata', None):
+                params['requestMetadata'] = request_metadata
+            if additional_model_response_fields_paths := model_settings.get(
+                'bedrock_additional_model_response_fields_paths', None
+            ):
+                params['additionalModelResponseFieldPaths'] = additional_model_response_fields_paths
+            if additional_model_requests_fields := model_settings.get('bedrock_additional_model_requests_fields', None):
+                params['additionalModelRequestFields'] = additional_model_requests_fields
+            if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
+                params['promptVariables'] = prompt_variables
+
         if tools:
             params['toolConfig'] = {'tools': tools}
             if tool_choice:
@@ -359,6 +423,10 @@ class BedrockConverseModel(Model):
                 bedrock_messages.append({'role': 'assistant', 'content': content})
             else:
                 assert_never(m)
+
+        if instructions := self._get_instructions(messages):
+            system_prompt.insert(0, {'text': instructions})
+
         return system_prompt, bedrock_messages
 
     @staticmethod
@@ -381,9 +449,12 @@ class BedrockConverseModel(Model):
                 elif item.is_image:
                     assert format in ('jpeg', 'png', 'gif', 'webp')
                     content.append({'image': {'format': format, 'source': {'bytes': item.data}}})
+                elif item.is_video:
+                    assert format in ('mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp')
+                    content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
                 else:
                     raise NotImplementedError('Binary content is not supported yet.')
-            elif isinstance(item, (ImageUrl, DocumentUrl)):
+            elif isinstance(item, (ImageUrl, DocumentUrl, VideoUrl)):
                 response = await cached_async_http_client().get(item.url)
                 response.raise_for_status()
                 if item.kind == 'image-url':
@@ -391,11 +462,20 @@ class BedrockConverseModel(Model):
                     assert format in ('jpeg', 'png', 'gif', 'webp'), f'Unsupported image format: {format}'
                     image: ImageBlockTypeDef = {'format': format, 'source': {'bytes': response.content}}
                     content.append({'image': image})
+
                 elif item.kind == 'document-url':
                     document_count += 1
                     name = f'Document {document_count}'
                     data = response.content
                     content.append({'document': {'name': name, 'format': item.format, 'source': {'bytes': data}}})
+
+                elif item.kind == 'video-url':
+                    format = item.media_type.split('/')[1]
+                    assert format in ('mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp'), (
+                        f'Unsupported video format: {format}'
+                    )
+                    video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': response.content}}
+                    content.append({'video': video})
             elif isinstance(item, AudioUrl):  # pragma: no cover
                 raise NotImplementedError('Audio is not supported yet.')
             else:
@@ -475,8 +555,8 @@ class BedrockStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
-    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> result.Usage:
-        return result.Usage(
+    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage:
+        return usage.Usage(
             request_tokens=metadata['usage']['inputTokens'],
             response_tokens=metadata['usage']['outputTokens'],
             total_tokens=metadata['usage']['totalTokens'],
@@ -494,9 +574,7 @@ class _AsyncIteratorWrapper(Generic[T]):
 
     async def __anext__(self) -> T:
         try:
-
-            item = await anyio.to_thread.run_sync(next, self.sync_iterator)
-            return item
+            return await anyio.to_thread.run_sync(next, self.sync_iterator)
         except RuntimeError as e:
             if type(e.__cause__) is StopIteration:
                 raise StopAsyncIteration
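With `VideoUrl` and binary video now mapped to Converse `video` blocks, a user prompt can mix text and video. A hedged sketch (the model id and URL are placeholders, and multi-part prompts are assumed to work for video as they already do for images):

from pydantic_ai import Agent
from pydantic_ai.messages import VideoUrl

agent = Agent('bedrock:us.amazon.nova-pro-v1:0')  # placeholder model id
result = agent.run_sync(
    ['What happens in this clip?', VideoUrl(url='https://example.com/clip.mp4')]
)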
pydantic_ai/models/cohere.py
CHANGED
@@ -2,12 +2,11 @@ from __future__ import annotations as _annotations
 
 from collections.abc import Iterable
 from dataclasses import dataclass, field
-from itertools import chain
 from typing import Literal, Union, cast
 
 from typing_extensions import assert_never
 
-from .. import ModelHTTPError, result
+from .. import ModelHTTPError, usage
 from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
@@ -134,7 +133,7 @@ class CohereModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         return self._process_response(response), _map_usage(response)
@@ -156,7 +155,7 @@ class CohereModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ChatResponse:
         tools = self._get_tools(model_request_parameters)
-        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+        cohere_messages = self._map_messages(messages)
         try:
             return await self.client.chat(
                 model=self._model_name,
@@ -194,33 +193,38 @@ class CohereModel(Model):
         )
         return ModelResponse(parts=parts, model_name=self._model_name)
 
-    def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
+    def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
         """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
-        if isinstance(message, ModelRequest):
-            yield from self._map_user_message(message)
-        elif isinstance(message, ModelResponse):
-            texts: list[str] = []
-            tool_calls: list[ToolCallV2] = []
-            for item in message.parts:
-                if isinstance(item, TextPart):
-                    texts.append(item.content)
-                elif isinstance(item, ToolCallPart):
-                    tool_calls.append(self._map_tool_call(item))
-                else:
-                    assert_never(item)
-            message_param = AssistantChatMessageV2(role='assistant')
-            if texts:
-                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
-            if tool_calls:
-                message_param.tool_calls = tool_calls
-            yield message_param
-        else:
-            assert_never(message)
+        cohere_messages: list[ChatMessageV2] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                cohere_messages.extend(self._map_user_message(message))
+            elif isinstance(message, ModelResponse):
+                texts: list[str] = []
+                tool_calls: list[ToolCallV2] = []
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        texts.append(item.content)
+                    elif isinstance(item, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(item))
+                    else:
+                        assert_never(item)
+                message_param = AssistantChatMessageV2(role='assistant')
+                if texts:
+                    message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+                if tool_calls:
+                    message_param.tool_calls = tool_calls
+                cohere_messages.append(message_param)
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            cohere_messages.insert(0, SystemChatMessageV2(role='system', content=instructions))
+        return cohere_messages
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
     @staticmethod
@@ -274,25 +278,25 @@ class CohereModel(Model):
         assert_never(part)
 
 
-def _map_usage(response: ChatResponse) -> result.Usage:
-    usage = response.usage
-    if usage is None:
-        return result.Usage()
+def _map_usage(response: ChatResponse) -> usage.Usage:
+    u = response.usage
+    if u is None:
+        return usage.Usage()
     else:
         details: dict[str, int] = {}
-        if usage.billed_units is not None:
-            if usage.billed_units.input_tokens:
-                details['input_tokens'] = int(usage.billed_units.input_tokens)
-            if usage.billed_units.output_tokens:
-                details['output_tokens'] = int(usage.billed_units.output_tokens)
-            if usage.billed_units.search_units:
-                details['search_units'] = int(usage.billed_units.search_units)
-            if usage.billed_units.classifications:
-                details['classifications'] = int(usage.billed_units.classifications)
-
-        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
-        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
-        return result.Usage(
+        if u.billed_units is not None:
+            if u.billed_units.input_tokens:
+                details['input_tokens'] = int(u.billed_units.input_tokens)
+            if u.billed_units.output_tokens:
+                details['output_tokens'] = int(u.billed_units.output_tokens)
+            if u.billed_units.search_units:  # pragma: no cover
+                details['search_units'] = int(u.billed_units.search_units)
+            if u.billed_units.classifications:  # pragma: no cover
+                details['classifications'] = int(u.billed_units.classifications)
+
+        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None
+        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None
+        return usage.Usage(
             request_tokens=request_tokens,
             response_tokens=response_tokens,
             total_tokens=(request_tokens or 0) + (response_tokens or 0),
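`_map_usage` now builds a `usage.Usage` with per-billed-unit details; callers read it off the run result. A small sketch (the `.usage()` accessor is the long-standing result API, assumed unchanged here, and the model name is a placeholder):

from pydantic_ai import Agent

agent = Agent('cohere:command-r-plus')  # placeholder model name
result = agent.run_sync('Hello!')
u = result.usage()
print(u.request_tokens, u.response_tokens, u.total_tokens)
print(u.details)  # e.g. {'input_tokens': ..., 'output_tokens': ...} from billed_units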
pydantic_ai/models/fallback.py
CHANGED
@@ -63,8 +63,9 @@ class FallbackModel(Model):
         exceptions: list[Exception] = []
 
         for model in self.models:
+            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
             try:
-                response, usage = await model.request(messages, model_settings, model_request_parameters)
+                response, usage = await model.request(messages, model_settings, customized_model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
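The fix above makes `FallbackModel` run each candidate's `customize_request_parameters` hook before delegating, instead of forwarding the raw parameters to every model. Usage is unchanged; a sketch (model names are illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.fallback import FallbackModel

# Each wrapped model now customizes the request parameters for itself
# before its `request` method is called.
model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
agent = Agent(model)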
pydantic_ai/models/function.py
CHANGED
@@ -91,8 +91,8 @@ class FunctionModel(Model):
     ) -> tuple[ModelResponse, usage.Usage]:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
-            model_request_parameters.allow_text_result,
-            model_request_parameters.result_tools,
+            model_request_parameters.allow_text_output,
+            model_request_parameters.output_tools,
             model_settings,
         )
 
@@ -117,8 +117,8 @@ class FunctionModel(Model):
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
-            model_request_parameters.allow_text_result,
-            model_request_parameters.result_tools,
+            model_request_parameters.allow_text_output,
+            model_request_parameters.output_tools,
             model_settings,
         )
 
@@ -158,10 +158,10 @@ class AgentInfo:
     These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
     [`tool_plain`][pydantic_ai.Agent.tool_plain] decorators.
     """
-    allow_text_result: bool
-    """Whether a plain text result is allowed."""
-    result_tools: list[ToolDefinition]
-    """The tools that can called as the final result of the run."""
+    allow_text_output: bool
+    """Whether a plain text output is allowed."""
+    output_tools: list[ToolDefinition]
+    """The tools that can called to produce the final output of the run."""
     model_settings: ModelSettings | None
     """The model settings passed to the run call."""
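The `AgentInfo` rename is visible from any custom `FunctionModel` function, which receives it as its second argument. A minimal sketch using the renamed fields (the asserts assume a run with no output tools registered):

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # The renamed AgentInfo fields from the hunk above:
    assert info.allow_text_output      # 0.0.54: allow_text_result
    assert info.output_tools == []     # 0.0.54: result_tools
    return ModelResponse(parts=[TextPart('hello world')])

agent = Agent(FunctionModel(model_fn))
print(agent.run_sync('hi'))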