pydantic-ai-slim 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- pydantic_ai/_agent_graph.py +60 -57
- pydantic_ai/_cli.py +18 -3
- pydantic_ai/_parts_manager.py +5 -4
- pydantic_ai/_run_context.py +2 -2
- pydantic_ai/_tool_manager.py +50 -29
- pydantic_ai/ag_ui.py +4 -4
- pydantic_ai/agent/__init__.py +69 -84
- pydantic_ai/agent/abstract.py +16 -18
- pydantic_ai/agent/wrapper.py +4 -6
- pydantic_ai/direct.py +4 -4
- pydantic_ai/durable_exec/temporal/_agent.py +13 -15
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/messages.py +16 -6
- pydantic_ai/models/__init__.py +5 -5
- pydantic_ai/models/anthropic.py +47 -46
- pydantic_ai/models/bedrock.py +25 -27
- pydantic_ai/models/cohere.py +20 -25
- pydantic_ai/models/fallback.py +15 -15
- pydantic_ai/models/function.py +7 -9
- pydantic_ai/models/gemini.py +43 -39
- pydantic_ai/models/google.py +59 -40
- pydantic_ai/models/groq.py +23 -19
- pydantic_ai/models/huggingface.py +27 -23
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mcp_sampling.py +1 -2
- pydantic_ai/models/mistral.py +24 -22
- pydantic_ai/models/openai.py +101 -45
- pydantic_ai/models/test.py +4 -5
- pydantic_ai/profiles/__init__.py +10 -1
- pydantic_ai/profiles/deepseek.py +1 -1
- pydantic_ai/profiles/moonshotai.py +1 -1
- pydantic_ai/profiles/openai.py +13 -3
- pydantic_ai/profiles/qwen.py +4 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/huggingface.py +27 -0
- pydantic_ai/providers/ollama.py +105 -0
- pydantic_ai/providers/openai.py +1 -1
- pydantic_ai/providers/openrouter.py +2 -0
- pydantic_ai/result.py +6 -6
- pydantic_ai/run.py +4 -11
- pydantic_ai/tools.py +9 -9
- pydantic_ai/usage.py +229 -67
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +47 -46
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py
CHANGED

@@ -42,7 +42,7 @@ from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 
 KnownModelName = TypeAliasType(
     'KnownModelName',
@@ -418,7 +418,7 @@ class Model(ABC):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Usage:
+    ) -> RequestUsage:
         """Make a request to the model for counting tokens."""
         # This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`.
         raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}')
@@ -480,7 +480,7 @@ class Model(ABC):
     @property
     @abstractmethod
     def system(self) -> str:
-        """The system / model provider, ex: openai.
+        """The model provider, ex: openai.
 
         Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
         so should use well-known values listed in
@@ -547,7 +547,7 @@ class StreamedResponse(ABC):
 
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
-    _usage: Usage = field(default_factory=Usage, init=False)
+    _usage: RequestUsage = field(default_factory=RequestUsage, init=False)
 
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
@@ -600,7 +600,7 @@ class StreamedResponse(ABC):
             usage=self.usage(),
         )
 
-    def usage(self) -> Usage:
+    def usage(self) -> RequestUsage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage
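The thread running through this release is the replacement of the old `Usage` type with `RequestUsage`. As a minimal sketch of what that means for implementers (not code from the diff; the function name and the four-characters-per-token estimate are made-up placeholders), a `count_tokens`-style helper now returns the new type:

from pydantic_ai.usage import RequestUsage

def estimate_prompt_usage(prompt: str) -> RequestUsage:
    # Hypothetical estimator: ~4 characters per token; illustration only, not a real tokenizer.
    return RequestUsage(input_tokens=max(1, len(prompt) // 4))

u = estimate_prompt_usage('Hello, world!')
assert u.input_tokens == 3  # 13 characters // 4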
pydantic_ai/models/anthropic.py
CHANGED

@@ -8,14 +8,6 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
-from anthropic.types.beta import (
-    BetaCitationsDelta,
-    BetaCodeExecutionToolResultBlock,
-    BetaCodeExecutionToolResultBlockParam,
-    BetaInputJSONDelta,
-    BetaServerToolUseBlockParam,
-    BetaWebSearchToolResultBlockParam,
-)
 from typing_extensions import assert_never
 
 from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
@@ -47,24 +39,21 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types.beta import (
         BetaBase64PDFBlockParam,
         BetaBase64PDFSourceParam,
+        BetaCitationsDelta,
         BetaCodeExecutionTool20250522Param,
+        BetaCodeExecutionToolResultBlock,
+        BetaCodeExecutionToolResultBlockParam,
         BetaContentBlock,
         BetaContentBlockParam,
         BetaImageBlockParam,
+        BetaInputJSONDelta,
         BetaMessage,
         BetaMessageParam,
         BetaMetadataParam,
@@ -78,6 +67,7 @@ try:
         BetaRawMessageStreamEvent,
         BetaRedactedThinkingBlock,
         BetaServerToolUseBlock,
+        BetaServerToolUseBlockParam,
         BetaSignatureDelta,
         BetaTextBlock,
         BetaTextBlockParam,
@@ -94,6 +84,7 @@ try:
         BetaToolUseBlockParam,
         BetaWebSearchTool20250305Param,
         BetaWebSearchToolResultBlock,
+        BetaWebSearchToolResultBlockParam,
     )
     from anthropic.types.beta.beta_web_search_tool_20250305_param import UserLocation
     from anthropic.types.model_param import ModelParam
@@ -146,7 +137,7 @@ class AnthropicModel(Model):
     client: AsyncAnthropic = field(repr=False)
 
     _model_name: AnthropicModelName = field(repr=False)
-    _system: str = field(default='anthropic', repr=False)
+    _provider: Provider[AsyncAnthropic] = field(repr=False)
 
     def __init__(
         self,
@@ -170,6 +161,7 @@ class AnthropicModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -178,6 +170,16 @@ class AnthropicModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)
 
+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -189,7 +191,6 @@ class AnthropicModel(Model):
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
     @asynccontextmanager
@@ -207,16 +208,6 @@ class AnthropicModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)
 
-    @property
-    def model_name(self) -> AnthropicModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _messages_create(
         self,
@@ -246,7 +237,9 @@ class AnthropicModel(Model):
     ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-        tools += self._get_builtin_tools(model_request_parameters)
+        builtin_tools, tool_headers = self._get_builtin_tools(model_request_parameters)
+        tools += builtin_tools
+
         tool_choice: BetaToolChoiceParam | None
 
         if not tools:
@@ -264,8 +257,10 @@ class AnthropicModel(Model):
 
         try:
             extra_headers = model_settings.get('extra_headers', {})
+            for k, v in tool_headers.items():
+                extra_headers.setdefault(k, v)
             extra_headers.setdefault('User-Agent', get_user_agent())
-
+
             return await self.client.beta.messages.create(
                 max_tokens=model_settings.get('max_tokens', 4096),
                 system=system_prompt or NOT_GIVEN,
@@ -330,7 +325,9 @@ class AnthropicModel(Model):
                 )
             )
 
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
+        return ModelResponse(
+            items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id
+        )
 
     async def _process_streamed_response(
         self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
@@ -352,8 +349,11 @@ class AnthropicModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
-    def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
+    def _get_builtin_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> tuple[list[BetaToolUnionParam], dict[str, str]]:
         tools: list[BetaToolUnionParam] = []
+        extra_headers: dict[str, str] = {}
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
@@ -361,18 +361,20 @@ class AnthropicModel(Model):
                     BetaWebSearchTool20250305Param(
                         name='web_search',
                         type='web_search_20250305',
+                        max_uses=tool.max_uses,
                         allowed_domains=tool.allowed_domains,
                         blocked_domains=tool.blocked_domains,
                         user_location=user_location,
                     )
                 )
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                extra_headers['anthropic-beta'] = 'code-execution-2025-05-22'
                 tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
            else:  # pragma: no cover
                raise UserError(
                    f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
                )
-        return tools
+        return tools, extra_headers
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
@@ -528,7 +530,7 @@ class AnthropicModel(Model):
     }
 
 
-def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
+def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -541,7 +543,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
        # - RawContentBlockStartEvent
        # - RawContentBlockDeltaEvent
        # - RawContentBlockStopEvent
-        return usage.Usage()
+        return usage.RequestUsage()
 
    # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
    # `response_tokens`
@@ -552,17 +554,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
    # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
    # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
-    request_tokens = (
-        details.get('input_tokens', 0)
-        + details.get('cache_creation_input_tokens', 0)
-        + details.get('cache_read_input_tokens', 0)
-    )
-
-    return usage.Usage(
-        request_tokens=request_tokens or None,
-        response_tokens=response_usage.output_tokens,
-        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
-        details=details or None,
+    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
+    cache_read_tokens = details.get('cache_read_input_tokens', 0)
+    request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
+
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        cache_read_tokens=cache_read_tokens,
+        cache_write_tokens=cache_write_tokens,
+        output_tokens=response_usage.output_tokens,
+        details=details,
    )
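A worked sketch of the new `_map_usage` arithmetic above, with invented numbers: Anthropic's cache counters are folded into `input_tokens` while still being reported separately:

from pydantic_ai.usage import RequestUsage

# Invented counter values, shaped like the details dict built in _map_usage.
details = {'input_tokens': 100, 'cache_creation_input_tokens': 20, 'cache_read_input_tokens': 30}
cache_write_tokens = details.get('cache_creation_input_tokens', 0)
cache_read_tokens = details.get('cache_read_input_tokens', 0)
u = RequestUsage(
    input_tokens=details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens,
    cache_read_tokens=cache_read_tokens,
    cache_write_tokens=cache_write_tokens,
    output_tokens=42,
    details=details,
)
assert u.input_tokens == 150  # cache reads/writes are counted as input tokens exactly once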
pydantic_ai/models/bedrock.py
CHANGED

@@ -190,17 +190,7 @@ class BedrockConverseModel(Model):
     client: BedrockRuntimeClient
 
     _model_name: BedrockModelName = field(repr=False)
-    _system: str = field(default='bedrock', repr=False)
-
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider, ex: openai."""
-        return self._system
+    _provider: Provider[BaseClient] = field(repr=False)
 
     def __init__(
         self,
@@ -226,10 +216,25 @@ class BedrockConverseModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
 
+    @property
+    def base_url(self) -> str:
+        return str(self.client.meta.endpoint_url)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
@@ -245,10 +250,6 @@ class BedrockConverseModel(Model):
 
         return {'toolSpec': tool_spec}
 
-    @property
-    def base_url(self) -> str:
-        return str(self.client.meta.endpoint_url)
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -258,7 +259,6 @@ class BedrockConverseModel(Model):
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
         model_response = await self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
     @asynccontextmanager
@@ -299,13 +299,12 @@ class BedrockConverseModel(Model):
                     tool_call_id=tool_use['toolUseId'],
                 ),
             )
-        u = usage.Usage(
-            request_tokens=response['usage']['inputTokens'],
-            response_tokens=response['usage']['outputTokens'],
-            total_tokens=response['usage']['totalTokens'],
+        u = usage.RequestUsage(
+            input_tokens=response['usage']['inputTokens'],
+            output_tokens=response['usage']['outputTokens'],
         )
         vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
-        return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id)
+        return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id)
 
     @overload
     async def _messages_create(
@@ -648,7 +647,7 @@ class BedrockStreamedResponse(StreamedResponse):
                 )
             if 'text' in delta:
                 maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
-                if maybe_event is not None:
+                if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
             if 'toolUse' in delta:
                 tool_use = delta['toolUse']
@@ -670,11 +669,10 @@ class BedrockStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
-    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage:
-        return usage.Usage(
-            request_tokens=metadata['usage']['inputTokens'],
-            response_tokens=metadata['usage']['outputTokens'],
-            total_tokens=metadata['usage']['totalTokens'],
+    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        return usage.RequestUsage(
+            input_tokens=metadata['usage']['inputTokens'],
+            output_tokens=metadata['usage']['outputTokens'],
        )
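Both Bedrock hunks drop the explicit totalTokens argument. A hedged sketch of the apparent intent (an assumption based on this diff, not verified against the 0.7.3 source): the total is now recoverable from the parts rather than stored:

from pydantic_ai.usage import RequestUsage

# Assumption: with no stored total, the total is simply the sum of the parts.
u = RequestUsage(input_tokens=10, output_tokens=5)
assert u.input_tokens + u.output_tokens == 15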
pydantic_ai/models/cohere.py
CHANGED

@@ -30,11 +30,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, check_allow_model_requests
 
 try:
     from cohere import (
@@ -106,7 +102,7 @@ class CohereModel(Model):
     client: AsyncClientV2 = field(repr=False)
 
     _model_name: CohereModelName = field(repr=False)
-    _system: str = field(default='cohere', repr=False)
+    _provider: Provider[AsyncClientV2] = field(repr=False)
 
     def __init__(
         self,
@@ -131,6 +127,7 @@ class CohereModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -140,6 +137,16 @@ class CohereModel(Model):
         client_wrapper = self.client._client_wrapper  # type: ignore
         return str(client_wrapper.get_base_url())
 
+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -149,19 +156,8 @@ class CohereModel(Model):
         check_allow_model_requests()
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
-    @property
-    def model_name(self) -> CohereModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _chat(
         self,
         messages: list[ModelMessage],
@@ -301,10 +297,10 @@ class CohereModel(Model):
             assert_never(part)
 
 
-def _map_usage(response: V2ChatResponse) -> usage.Usage:
+def _map_usage(response: V2ChatResponse) -> usage.RequestUsage:
     u = response.usage
     if u is None:
-        return usage.Usage()
+        return usage.RequestUsage()
     else:
         details: dict[str, int] = {}
         if u.billed_units is not None:
@@ -317,11 +313,10 @@ def _map_usage(response: V2ChatResponse) -> usage.Usage:
             if u.billed_units.classifications:  # pragma: no cover
                 details['classifications'] = int(u.billed_units.classifications)
 
-        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None
-        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None
-        return usage.Usage(
-            request_tokens=request_tokens,
-            response_tokens=response_tokens,
-            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0
+        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0
+        return usage.RequestUsage(
+            input_tokens=request_tokens,
+            output_tokens=response_tokens,
             details=details,
        )
pydantic_ai/models/fallback.py
CHANGED

@@ -33,8 +33,8 @@ class FallbackModel(Model):
 
     def __init__(
         self,
-        default_model: Model | KnownModelName,
-        *fallback_models: Model | KnownModelName,
+        default_model: Model | KnownModelName | str,
+        *fallback_models: Model | KnownModelName | str,
         fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
     ):
         """Initialize a fallback model instance.
@@ -52,6 +52,19 @@ class FallbackModel(Model):
         else:
             self._fallback_on = fallback_on
 
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return f'fallback:{",".join(model.model_name for model in self.models)}'
+
+    @property
+    def system(self) -> str:
+        return f'fallback:{",".join(model.system for model in self.models)}'
+
+    @property
+    def base_url(self) -> str | None:
+        return self.models[0].base_url
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -121,19 +134,6 @@ class FallbackModel(Model):
             if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                 span.set_attributes(InstrumentedModel.model_attributes(model))
 
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return f'fallback:{",".join(model.model_name for model in self.models)}'
-
-    @property
-    def system(self) -> str:
-        return f'fallback:{",".join(model.system for model in self.models)}'
-
-    @property
-    def base_url(self) -> str | None:
-        return self.models[0].base_url
-
 
 def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
     """Create a default fallback condition for the given exceptions."""
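The widened `__init__` annotations mean arbitrary `provider:model` strings type-check alongside `Model` instances and `KnownModelName` values. A sketch (the model names are illustrative only; constructing provider-backed models may require the relevant extras and API keys):

from pydantic_ai.models.fallback import FallbackModel

fallback = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
# Per the model_name property above, the joined name would be:
# 'fallback:gpt-4o,claude-3-5-sonnet-latest'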
pydantic_ai/models/function.py
CHANGED

@@ -138,7 +138,6 @@ class FunctionModel(Model):
         # Add usage data if not already present
         if not response.usage.has_values():  # pragma: no branch
             response.usage = _estimate_usage(chain(messages, [response]))
-            response.usage.requests = 1
         return response
 
     @asynccontextmanager
@@ -270,7 +269,7 @@ class FunctionStreamedResponse(StreamedResponse):
         async for item in self._iter:
             if isinstance(item, str):
                 response_tokens = _estimate_string_tokens(item)
-                self._usage += usage.Usage(response_tokens=response_tokens)
+                self._usage += usage.RequestUsage(output_tokens=response_tokens)
                 maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -279,7 +278,7 @@ class FunctionStreamedResponse(StreamedResponse):
             if isinstance(delta, DeltaThinkingPart):
                 if delta.content:  # pragma: no branch
                     response_tokens = _estimate_string_tokens(delta.content)
-                    self._usage += usage.Usage(response_tokens=response_tokens)
+                    self._usage += usage.RequestUsage(output_tokens=response_tokens)
                     yield self._parts_manager.handle_thinking_delta(
                         vendor_part_id=dtc_index,
                         content=delta.content,
@@ -288,7 +287,7 @@ class FunctionStreamedResponse(StreamedResponse):
             elif isinstance(delta, DeltaToolCall):
                 if delta.json_args:
                     response_tokens = _estimate_string_tokens(delta.json_args)
-                    self._usage += usage.Usage(response_tokens=response_tokens)
+                    self._usage += usage.RequestUsage(output_tokens=response_tokens)
                     maybe_event = self._parts_manager.handle_tool_call_delta(
                         vendor_part_id=dtc_index,
                         tool_name=delta.name,
@@ -311,7 +310,7 @@ class FunctionStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
     """Very rough guesstimate of the token usage associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
@@ -349,10 +348,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
             assert_never(part)
         else:
            assert_never(message)
-    return usage.Usage(
-        request_tokens=request_tokens,
-        response_tokens=response_tokens,
-        total_tokens=request_tokens + response_tokens,
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        output_tokens=response_tokens,
    )
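A sketch of the accumulation pattern `FunctionStreamedResponse` relies on above: `RequestUsage` values add together, so per-chunk estimates can be summed as the stream progresses:

from pydantic_ai import usage

total = usage.RequestUsage()
for chunk in ('hello ', 'world'):
    # One invented "token" per whitespace-separated word, purely for illustration.
    total += usage.RequestUsage(output_tokens=len(chunk.split()))
assert total.output_tokens == 2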