pydantic-ai-slim 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +2 -1
- pydantic_ai/_agent_graph.py +2 -2
- pydantic_ai/_cli.py +18 -3
- pydantic_ai/_run_context.py +2 -2
- pydantic_ai/ag_ui.py +4 -4
- pydantic_ai/agent/__init__.py +7 -9
- pydantic_ai/agent/abstract.py +16 -18
- pydantic_ai/agent/wrapper.py +4 -6
- pydantic_ai/builtin_tools.py +9 -1
- pydantic_ai/direct.py +4 -4
- pydantic_ai/durable_exec/temporal/_agent.py +13 -15
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/messages.py +16 -6
- pydantic_ai/models/__init__.py +5 -5
- pydantic_ai/models/anthropic.py +27 -26
- pydantic_ai/models/bedrock.py +24 -26
- pydantic_ai/models/cohere.py +20 -25
- pydantic_ai/models/fallback.py +15 -15
- pydantic_ai/models/function.py +7 -9
- pydantic_ai/models/gemini.py +43 -39
- pydantic_ai/models/google.py +76 -50
- pydantic_ai/models/groq.py +22 -19
- pydantic_ai/models/huggingface.py +18 -21
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mcp_sampling.py +1 -2
- pydantic_ai/models/mistral.py +24 -22
- pydantic_ai/models/openai.py +98 -44
- pydantic_ai/models/test.py +4 -5
- pydantic_ai/profiles/openai.py +13 -3
- pydantic_ai/providers/openai.py +1 -1
- pydantic_ai/result.py +5 -5
- pydantic_ai/run.py +4 -11
- pydantic_ai/tools.py +5 -2
- pydantic_ai/usage.py +230 -68
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/METADATA +10 -4
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/RECORD +39 -39
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py
CHANGED
@@ -42,7 +42,7 @@ from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage

 KnownModelName = TypeAliasType(
     'KnownModelName',
@@ -418,7 +418,7 @@ class Model(ABC):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Usage:
+    ) -> RequestUsage:
         """Make a request to the model for counting tokens."""
         # This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`.
         raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}')
@@ -480,7 +480,7 @@ class Model(ABC):
     @property
     @abstractmethod
     def system(self) -> str:
-        """The system / model provider, ex: openai.
+        """The model provider, ex: openai.

         Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
         so should use well-known values listed in
@@ -547,7 +547,7 @@ class StreamedResponse(ABC):

     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
-    _usage: Usage = field(default_factory=Usage, init=False)
+    _usage: RequestUsage = field(default_factory=RequestUsage, init=False)

     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
@@ -600,7 +600,7 @@ class StreamedResponse(ABC):
             usage=self.usage(),
         )

-    def usage(self) -> Usage:
+    def usage(self) -> RequestUsage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage
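Across this release the per-request usage type changes from Usage to RequestUsage: the token-counting hook on Model and StreamedResponse.usage() now both return RequestUsage. A minimal sketch of constructing one directly, using only keyword arguments that appear in the new code in this diff (anything beyond them would be an assumption):

    from pydantic_ai.usage import RequestUsage

    # Field names below are the ones used by the 0.7.4 model adapters in this diff.
    u = RequestUsage(
        input_tokens=120,
        output_tokens=45,
        cache_read_tokens=10,
        details={'thoughts_tokens': 5},
    )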
pydantic_ai/models/anthropic.py
CHANGED
@@ -137,7 +137,7 @@ class AnthropicModel(Model):
     client: AsyncAnthropic = field(repr=False)

     _model_name: AnthropicModelName = field(repr=False)
-
+    _provider: Provider[AsyncAnthropic] = field(repr=False)

     def __init__(
         self,
@@ -161,6 +161,7 @@ class AnthropicModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -169,6 +170,16 @@ class AnthropicModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -180,7 +191,6 @@ class AnthropicModel(Model):
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -198,16 +208,6 @@ class AnthropicModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> AnthropicModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _messages_create(
         self,
@@ -325,7 +325,9 @@ class AnthropicModel(Model):
             )
         )

-        return ModelResponse(
+        return ModelResponse(
+            items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id
+        )

     async def _process_streamed_response(
         self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
@@ -528,7 +530,7 @@ class AnthropicModel(Model):
     }


-def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
+def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -541,7 +543,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
         # - RawContentBlockStartEvent
         # - RawContentBlockDeltaEvent
         # - RawContentBlockStopEvent
-        return usage.Usage()
+        return usage.RequestUsage()

     # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
     # `response_tokens`
@@ -552,17 +554,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
     # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
     # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
     # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
-
-
-
-
-
-
-
-
-
-
-        details=details or None,
+    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
+    cache_read_tokens = details.get('cache_read_input_tokens', 0)
+    request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
+
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        cache_read_tokens=cache_read_tokens,
+        cache_write_tokens=cache_write_tokens,
+        output_tokens=response_usage.output_tokens,
+        details=details,
     )

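One behavioural detail in the new _map_usage above: input_tokens now counts uncached, cache-write and cache-read tokens together, while the cached counts remain available separately. A rough sketch of the same arithmetic over an illustrative Anthropic-style usage dict (the values are made up):

    details = {'input_tokens': 100, 'cache_creation_input_tokens': 30, 'cache_read_input_tokens': 20}

    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
    cache_read_tokens = details.get('cache_read_input_tokens', 0)
    # Same summation as the hunk above: all three buckets land in input_tokens.
    input_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
    assert input_tokens == 150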
pydantic_ai/models/bedrock.py
CHANGED
@@ -190,17 +190,7 @@ class BedrockConverseModel(Model):
     client: BedrockRuntimeClient

     _model_name: BedrockModelName = field(repr=False)
-
-
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider, ex: openai."""
-        return self._system
+    _provider: Provider[BaseClient] = field(repr=False)

     def __init__(
         self,
@@ -226,10 +216,25 @@ class BedrockConverseModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def base_url(self) -> str:
+        return str(self.client.meta.endpoint_url)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

@@ -245,10 +250,6 @@ class BedrockConverseModel(Model):

         return {'toolSpec': tool_spec}

-    @property
-    def base_url(self) -> str:
-        return str(self.client.meta.endpoint_url)
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -258,7 +259,6 @@ class BedrockConverseModel(Model):
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
         model_response = await self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -299,13 +299,12 @@ class BedrockConverseModel(Model):
                     tool_call_id=tool_use['toolUseId'],
                 ),
             )
-        u = usage.Usage(
-
-
-            total_tokens=response['usage']['totalTokens'],
+        u = usage.RequestUsage(
+            input_tokens=response['usage']['inputTokens'],
+            output_tokens=response['usage']['outputTokens'],
         )
         vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
-        return ModelResponse(items, usage=u, model_name=self.model_name,
+        return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id)

     @overload
     async def _messages_create(
@@ -670,11 +669,10 @@ class BedrockStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

-    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage:
-        return usage.Usage(
-
-
-            total_tokens=metadata['usage']['totalTokens'],
+    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        return usage.RequestUsage(
+            input_tokens=metadata['usage']['inputTokens'],
+            output_tokens=metadata['usage']['outputTokens'],
         )

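The Bedrock mapping no longer forwards totalTokens; only the input and output counts are passed to RequestUsage. A hedged sketch with an illustrative Converse-style usage dict (not a real API response):

    from pydantic_ai import usage

    converse_usage = {'inputTokens': 200, 'outputTokens': 80, 'totalTokens': 280}  # illustrative values

    u = usage.RequestUsage(
        input_tokens=converse_usage['inputTokens'],
        output_tokens=converse_usage['outputTokens'],
    )
    # 'totalTokens' is dropped here; the total is implied by the two parts (200 + 80).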
pydantic_ai/models/cohere.py
CHANGED
@@ -30,11 +30,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, check_allow_model_requests

 try:
     from cohere import (
@@ -106,7 +102,7 @@ class CohereModel(Model):
     client: AsyncClientV2 = field(repr=False)

     _model_name: CohereModelName = field(repr=False)
-
+    _provider: Provider[AsyncClientV2] = field(repr=False)

     def __init__(
         self,
@@ -131,6 +127,7 @@ class CohereModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -140,6 +137,16 @@ class CohereModel(Model):
         client_wrapper = self.client._client_wrapper  # type: ignore
         return str(client_wrapper.get_base_url())

+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -149,19 +156,8 @@ class CohereModel(Model):
         check_allow_model_requests()
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

-    @property
-    def model_name(self) -> CohereModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _chat(
         self,
         messages: list[ModelMessage],
@@ -301,10 +297,10 @@ class CohereModel(Model):
             assert_never(part)


-def _map_usage(response: V2ChatResponse) -> usage.Usage:
+def _map_usage(response: V2ChatResponse) -> usage.RequestUsage:
     u = response.usage
     if u is None:
-        return usage.Usage()
+        return usage.RequestUsage()
     else:
         details: dict[str, int] = {}
         if u.billed_units is not None:
@@ -317,11 +313,10 @@ def _map_usage(response: V2ChatResponse) -> usage.Usage:
             if u.billed_units.classifications:  # pragma: no cover
                 details['classifications'] = int(u.billed_units.classifications)

-        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None
-        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None
-        return usage.Usage(
-
-
-            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0
+        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0
+        return usage.RequestUsage(
+            input_tokens=request_tokens,
+            output_tokens=response_tokens,
             details=details,
         )
pydantic_ai/models/fallback.py
CHANGED
@@ -33,8 +33,8 @@ class FallbackModel(Model):

     def __init__(
         self,
-        default_model: Model | KnownModelName,
-        *fallback_models: Model | KnownModelName,
+        default_model: Model | KnownModelName | str,
+        *fallback_models: Model | KnownModelName | str,
         fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
     ):
         """Initialize a fallback model instance.
@@ -52,6 +52,19 @@ class FallbackModel(Model):
         else:
             self._fallback_on = fallback_on

+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return f'fallback:{",".join(model.model_name for model in self.models)}'
+
+    @property
+    def system(self) -> str:
+        return f'fallback:{",".join(model.system for model in self.models)}'
+
+    @property
+    def base_url(self) -> str | None:
+        return self.models[0].base_url
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -121,19 +134,6 @@ class FallbackModel(Model):
             if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                 span.set_attributes(InstrumentedModel.model_attributes(model))

-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return f'fallback:{",".join(model.model_name for model in self.models)}'
-
-    @property
-    def system(self) -> str:
-        return f'fallback:{",".join(model.system for model in self.models)}'
-
-    @property
-    def base_url(self) -> str | None:
-        return self.models[0].base_url
-

 def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
     """Create a default fallback condition for the given exceptions."""
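The widened __init__ signature means FallbackModel now also accepts plain model-name strings, and the relocated properties expose a combined model_name/system for the wrapped models. A minimal sketch; the model identifiers are placeholders and the import path is simply the module being patched:

    from pydantic_ai.models.fallback import FallbackModel

    # Plain strings are now accepted alongside Model instances and KnownModelName values.
    model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
    # model.model_name joins the wrapped models' names, e.g. 'fallback:gpt-4o,claude-3-5-sonnet-latest'.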
pydantic_ai/models/function.py
CHANGED
@@ -138,7 +138,6 @@ class FunctionModel(Model):
         # Add usage data if not already present
         if not response.usage.has_values():  # pragma: no branch
             response.usage = _estimate_usage(chain(messages, [response]))
-            response.usage.requests = 1
         return response

     @asynccontextmanager
@@ -270,7 +269,7 @@ class FunctionStreamedResponse(StreamedResponse):
         async for item in self._iter:
             if isinstance(item, str):
                 response_tokens = _estimate_string_tokens(item)
-                self._usage += usage.
+                self._usage += usage.RequestUsage(output_tokens=response_tokens)
                 maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -279,7 +278,7 @@ class FunctionStreamedResponse(StreamedResponse):
             if isinstance(delta, DeltaThinkingPart):
                 if delta.content:  # pragma: no branch
                     response_tokens = _estimate_string_tokens(delta.content)
-                    self._usage += usage.
+                    self._usage += usage.RequestUsage(output_tokens=response_tokens)
                     yield self._parts_manager.handle_thinking_delta(
                         vendor_part_id=dtc_index,
                         content=delta.content,
@@ -288,7 +287,7 @@ class FunctionStreamedResponse(StreamedResponse):
             elif isinstance(delta, DeltaToolCall):
                 if delta.json_args:
                     response_tokens = _estimate_string_tokens(delta.json_args)
-                    self._usage += usage.
+                    self._usage += usage.RequestUsage(output_tokens=response_tokens)
                     maybe_event = self._parts_manager.handle_tool_call_delta(
                         vendor_part_id=dtc_index,
                         tool_name=delta.name,
@@ -311,7 +310,7 @@ class FunctionStreamedResponse(StreamedResponse):
         return self._timestamp


-def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
     """Very rough guesstimate of the token usage associated with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
@@ -349,10 +348,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
                 assert_never(part)
         else:
             assert_never(message)
-    return usage.Usage(
-
-
-        total_tokens=request_tokens + response_tokens,
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        output_tokens=response_tokens,
     )

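FunctionStreamedResponse now accumulates one RequestUsage per streamed delta, counting only output_tokens. A small sketch of that accumulation pattern, assuming (as the streaming code above does) that += sums the token fields; the per-delta numbers are made up:

    from pydantic_ai.usage import RequestUsage

    total = RequestUsage()
    for estimated in (3, 7, 2):  # stand-ins for _estimate_string_tokens results
        total += RequestUsage(output_tokens=estimated)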
pydantic_ai/models/gemini.py
CHANGED
@@ -40,14 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 LatestGeminiModelNames = Literal[
     'gemini-2.0-flash',
@@ -108,10 +101,9 @@ class GeminiModel(Model):
     client: httpx.AsyncClient = field(repr=False)

     _model_name: GeminiModelName = field(repr=False)
-    _provider:
+    _provider: Provider[httpx.AsyncClient] = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='gemini', repr=False)

     def __init__(
         self,
@@ -132,11 +124,10 @@ class GeminiModel(Model):
             settings: Default model settings for this model instance.
         """
         self._model_name = model_name
-        self._provider = provider

         if isinstance(provider, str):
             provider = infer_provider(provider)
-        self.
+        self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)

@@ -147,6 +138,16 @@ class GeminiModel(Model):
         assert self._url is not None, 'URL not initialized'  # pragma: no cover
         return self._url  # pragma: no cover

+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -175,16 +176,6 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response, model_request_parameters)

-    @property
-    def model_name(self) -> GeminiModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
@@ -237,7 +228,7 @@ class GeminiModel(Model):
             request_data['safetySettings'] = gemini_safety_settings

         if gemini_labels := model_settings.get('gemini_labels'):
-            if self.
+            if self._provider.name == 'google-vertex':
                 request_data['labels'] = gemini_labels  # pragma: lax no cover

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
@@ -278,7 +269,6 @@ class GeminiModel(Model):
         if finish_reason:
             vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
             parts,
             response.get('model_version', self._model_name),
@@ -673,7 +663,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
 def _process_response_from_parts(
     parts: Sequence[_GeminiPartUnion],
     model_name: GeminiModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -693,7 +683,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, usage=usage, model_name=model_name,
+        parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -859,31 +849,45 @@ class _GeminiUsageMetaData(TypedDict, total=False):
 ]


-def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return usage.Usage()
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count
+    if cached_content_token_count := metadata.get('cached_content_token_count', 0):
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.get('thoughts_token_count', 0):
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
-        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
     for key, metadata_details in metadata.items():
         if key.endswith('_details') and metadata_details:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-
-
-
-
-
-
+                modality = detail['modality']
+                details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0)
+                if value and modality == 'AUDIO':
+                    if key == 'prompt_tokens_details':
+                        input_audio_tokens = value
+                    elif key == 'candidates_tokens_details':
+                        output_audio_tokens = value
+                    elif key == 'cache_tokens_details':  # pragma: no branch
+                        cache_audio_read_tokens = value
+
+    return usage.RequestUsage(
+        input_tokens=metadata.get('prompt_token_count', 0),
+        output_tokens=metadata.get('candidates_token_count', 0),
+        cache_read_tokens=cached_content_token_count,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )

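The new loop in _metadata_as_usage maps Gemini's per-modality token details onto the audio fields of RequestUsage. A rough standalone sketch of the same bookkeeping over an illustrative usage_metadata dict (not a real API payload), covering just the prompt branch:

    metadata = {
        'prompt_token_count': 50,
        'candidates_token_count': 20,
        'prompt_tokens_details': [{'modality': 'AUDIO', 'token_count': 30}],
    }

    details: dict[str, int] = {}
    input_audio_tokens = 0
    for key, metadata_details in metadata.items():
        if key.endswith('_details') and metadata_details:
            suffix = key.removesuffix('_details')
            for detail in metadata_details:
                modality = detail['modality']
                details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0)
                if value and modality == 'AUDIO' and key == 'prompt_tokens_details':
                    input_audio_tokens = value

    assert details == {'audio_prompt_tokens': 30} and input_audio_tokens == 30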