pydantic-ai-slim 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +2 -1
- pydantic_ai/_agent_graph.py +2 -2
- pydantic_ai/_cli.py +18 -3
- pydantic_ai/_run_context.py +2 -2
- pydantic_ai/ag_ui.py +4 -4
- pydantic_ai/agent/__init__.py +7 -9
- pydantic_ai/agent/abstract.py +16 -18
- pydantic_ai/agent/wrapper.py +4 -6
- pydantic_ai/builtin_tools.py +9 -1
- pydantic_ai/direct.py +4 -4
- pydantic_ai/durable_exec/temporal/_agent.py +13 -15
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/messages.py +16 -6
- pydantic_ai/models/__init__.py +5 -5
- pydantic_ai/models/anthropic.py +27 -26
- pydantic_ai/models/bedrock.py +24 -26
- pydantic_ai/models/cohere.py +20 -25
- pydantic_ai/models/fallback.py +15 -15
- pydantic_ai/models/function.py +7 -9
- pydantic_ai/models/gemini.py +43 -39
- pydantic_ai/models/google.py +76 -50
- pydantic_ai/models/groq.py +22 -19
- pydantic_ai/models/huggingface.py +18 -21
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mcp_sampling.py +1 -2
- pydantic_ai/models/mistral.py +24 -22
- pydantic_ai/models/openai.py +98 -44
- pydantic_ai/models/test.py +4 -5
- pydantic_ai/profiles/openai.py +13 -3
- pydantic_ai/providers/openai.py +1 -1
- pydantic_ai/result.py +5 -5
- pydantic_ai/run.py +4 -11
- pydantic_ai/tools.py +5 -2
- pydantic_ai/usage.py +230 -68
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/METADATA +10 -4
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/RECORD +39 -39
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py CHANGED

@@ -13,7 +13,7 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
 from .._run_context import RunContext
-from ..builtin_tools import CodeExecutionTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, UrlContextTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -72,6 +72,7 @@ try:
     ToolConfigDict,
     ToolDict,
     ToolListUnionDict,
+    UrlContextDict,
 )

 from ..providers.google import GoogleProvider
@@ -144,7 +145,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)

     def __init__(
         self,
@@ -168,9 +168,7 @@ class GoogleModel(Model):

         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -179,6 +177,16 @@ class GoogleModel(Model):
     def base_url(self) -> str:
         return self._provider.base_url

+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -195,7 +203,7 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> usage.Usage:
+    ) -> usage.RequestUsage:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
@@ -209,9 +217,9 @@ class GoogleModel(Model):
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
-            config.update(
+            config.update(  # pragma: lax no cover
                 system_instruction=generation_config.get('system_instruction'),
                 tools=cast(list[ToolDict], generation_config.get('tools')),
                 # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
@@ -238,9 +246,8 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior(  # pragma: no cover
                 'Total tokens missing from Gemini response', str(response)
             )
-        return usage.Usage(
-
-            total_tokens=response.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response.total_tokens,
         )

     @asynccontextmanager
@@ -256,16 +263,6 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
@@ -274,6 +271,8 @@ class GoogleModel(Model):
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 tools.append(ToolDict(google_search=GoogleSearchDict()))
+            elif isinstance(tool, UrlContextTool):
+                tools.append(ToolDict(url_context=UrlContextDict()))
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
             else:  # pragma: no cover
@@ -378,23 +377,27 @@ class GoogleModel(Model):
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
-
-
+        candidate = response.candidates[0]
+        if candidate.content is None or candidate.content.parts is None:
+            if candidate.finish_reason == 'SAFETY':
                 raise UnexpectedModelBehavior('Safety settings triggered', str(response))
             else:
                 raise UnexpectedModelBehavior(
                     'Content field missing from Gemini response', str(response)
                 )  # pragma: no cover
-        parts =
+        parts = candidate.content.parts or []
         vendor_id = response.response_id or None
         vendor_details: dict[str, Any] | None = None
-        finish_reason =
+        finish_reason = candidate.finish_reason
         if finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
-            parts,
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
         )

     async def _process_streamed_response(
@@ -527,10 +530,13 @@ class GeminiStreamedResponse(StreamedResponse):

             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
-            if candidate.content is None:
-
-
+            if candidate.content is None or candidate.content.parts is None:
+                if candidate.finish_reason == 'SAFETY':  # pragma: no cover
+                    raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
+                else:  # pragma: no cover
+                    raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+            parts = candidate.content.parts or []
+            for part in parts:
                 if part.text is not None:
                     if part.thought:
                         yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
@@ -590,7 +596,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -627,7 +633,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage,
+        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -647,31 +653,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict:
     return ToolConfigDict(function_calling_config=function_calling_config)


-def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     metadata = response.usage_metadata
     if metadata is None:
-        return usage.Usage()
-    metadata = metadata.model_dump(exclude_defaults=True)
-
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count
+    if cached_content_token_count := metadata.cached_content_token_count:
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.thoughts_token_count:
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
         details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

-
-
-
-
-
-
-
-
-
-
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
+    for prefix, metadata_details in [
+        ('prompt', metadata.prompt_tokens_details),
+        ('cache', metadata.cache_tokens_details),
+        ('candidates', metadata.candidates_tokens_details),
+        ('tool_use_prompt', metadata.tool_use_prompt_tokens_details),
+    ]:
+        assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details
+        if not metadata_details:
+            continue
+        for detail in metadata_details:
+            if not detail.modality or not detail.token_count:  # pragma: no cover
+                continue
+            details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
+            if detail.modality != 'AUDIO':
+                continue
+            if metadata_details is metadata.prompt_tokens_details:
+                input_audio_tokens = detail.token_count
+            elif metadata_details is metadata.candidates_tokens_details:
+                output_audio_tokens = detail.token_count
+            elif metadata_details is metadata.cache_tokens_details:  # pragma: no branch
+                cache_audio_read_tokens = detail.token_count
+
+    return usage.RequestUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        cache_read_tokens=cached_content_token_count or 0,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
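The hunks above also wire the new `UrlContextTool` builtin tool into `GoogleModel`. A minimal sketch of how a consumer might opt into it, assuming the `google` optional dependency and credentials are configured; the model name and prompt are examples only, and `UrlContextTool()` is assumed to take no arguments:

```python
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import UrlContextTool, WebSearchTool
from pydantic_ai.models.google import GoogleModel

# GoogleModel._get_tools maps UrlContextTool to ToolDict(url_context=UrlContextDict()),
# alongside the existing WebSearchTool and CodeExecutionTool handling shown in the diff.
model = GoogleModel('gemini-2.5-flash')
agent = Agent(model, builtin_tools=[UrlContextTool(), WebSearchTool()])

result = agent.run_sync('Summarise the page at https://example.com in two sentences.')
print(result.output)
```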
pydantic_ai/models/groq.py CHANGED

@@ -118,7 +118,7 @@ class GroqModel(Model):
     client: AsyncGroq = field(repr=False)

     _model_name: GroqModelName = field(repr=False)
-
+    _provider: Provider[AsyncGroq] = field(repr=False)

     def __init__(
         self,
@@ -143,6 +143,7 @@ class GroqModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -151,6 +152,16 @@ class GroqModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -162,7 +173,6 @@ class GroqModel(Model):
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -180,16 +190,6 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> GroqModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -285,7 +285,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -484,7 +488,7 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
     response_usage = None
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
@@ -492,10 +496,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
         response_usage = completion.x_groq.usage

     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()

-    return usage.Usage(
-
-
-        total_tokens=response_usage.total_tokens,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
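Both Groq hunks replace `usage.Usage` with the new `usage.RequestUsage`, which carries per-request token counts under the renamed `input_tokens`/`output_tokens` fields, with provider-specific extras in `details`. A minimal sketch of the shape `_map_usage` now returns, constructed directly with made-up numbers:

```python
from pydantic_ai import usage

# the equivalent of what GroqModel's _map_usage builds from a completion's usage block
u = usage.RequestUsage(input_tokens=120, output_tokens=45)
print(u.input_tokens, u.output_tokens)

# provider-specific extras still travel in the details dict, as in the Google mapping above
u_with_details = usage.RequestUsage(input_tokens=10, output_tokens=3, details={'thoughts_tokens': 2})
```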
pydantic_ai/models/huggingface.py CHANGED

@@ -114,7 +114,7 @@ class HuggingFaceModel(Model):
     client: AsyncInferenceClient = field(repr=False)

     _model_name: str = field(repr=False)
-
+    _provider: Provider[AsyncInferenceClient] = field(repr=False)

     def __init__(
         self,
@@ -134,13 +134,23 @@ class HuggingFaceModel(Model):
             settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
-        self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -152,7 +162,6 @@ class HuggingFaceModel(Model):
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -169,16 +178,6 @@ class HuggingFaceModel(Model):
         )
         yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> HuggingFaceModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -272,7 +271,7 @@ class HuggingFaceModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -481,14 +480,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()

-    return usage.Usage(
-
-
-        total_tokens=response_usage.total_tokens,
-        details=None,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
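As in the Groq diff, the removed `_system` field is replaced by deriving `Model.system` from the provider, and `model_name`/`system` become properties defined next to `base_url`. A small sketch of what this looks like from the outside, assuming the `huggingface` optional dependency and a Hugging Face token are configured; the model id is only an example:

```python
from pydantic_ai.models.huggingface import HuggingFaceModel

model = HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct', provider='huggingface')
print(model.model_name)  # the model id passed in
print(model.system)      # the provider name, e.g. 'huggingface'
```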
pydantic_ai/models/instrumented.py CHANGED

@@ -280,14 +280,14 @@ class InstrumentedModel(WrapperModel):
                 'gen_ai.request.model': request_model,
                 'gen_ai.response.model': response_model,
             }
-            if response.usage.request_tokens:  # pragma: no branch
+            if response.usage.input_tokens:  # pragma: no branch
                 self.instrumentation_settings.tokens_histogram.record(
-                    response.usage.request_tokens,
+                    response.usage.input_tokens,
                     {**metric_attributes, 'gen_ai.token.type': 'input'},
                 )
-            if response.usage.response_tokens:  # pragma: no branch
+            if response.usage.output_tokens:  # pragma: no branch
                 self.instrumentation_settings.tokens_histogram.record(
-                    response.usage.response_tokens,
+                    response.usage.output_tokens,
                     {**metric_attributes, 'gen_ai.token.type': 'output'},
                 )
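The instrumentation now reads the renamed fields straight off each response's `RequestUsage`, and downstream code that inspects message history sees the same rename plus the `vendor_id` to `provider_request_id` change from the provider diffs above. A hedged sketch, with the model string as a placeholder for any configured model:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('openai:gpt-4o')  # example model string; any configured model works
result = agent.run_sync('Say hello.')

for message in result.all_messages():
    if isinstance(message, ModelResponse):
        # in 0.7.2 these were roughly usage.request_tokens / usage.response_tokens and vendor_id
        print(message.usage.input_tokens, message.usage.output_tokens, message.provider_request_id)
```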
pydantic_ai/models/mcp_sampling.py CHANGED

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast

-from .. import _mcp, exceptions, usage
+from .. import _mcp, exceptions
 from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
@@ -63,7 +63,6 @@ class MCPSamplingModel(Model):
         if result.role == 'assistant':
             return ModelResponse(
                 parts=[_mcp.map_from_sampling_content(result.content)],
-                usage=usage.Usage(requests=1),
                 model_name=result.model,
             )
         else:
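With the `Usage(requests=1)` bookkeeping removed, a `ModelResponse` no longer needs an explicit usage argument when there is nothing to report. A minimal sketch (not the library's own code) of constructing such a response, assuming the usage field defaults to an empty `RequestUsage`:

```python
from pydantic_ai.messages import ModelResponse, TextPart
from pydantic_ai.usage import RequestUsage

response = ModelResponse(parts=[TextPart(content='hello')], model_name='example-model')
assert isinstance(response.usage, RequestUsage)  # assumed default: an empty RequestUsage
```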
pydantic_ai/models/mistral.py CHANGED

@@ -40,7 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 from . import (
     Model,
     ModelRequestParameters,
@@ -120,7 +120,7 @@ class MistralModel(Model):
     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

     _model_name: MistralModelName = field(repr=False)
-
+    _provider: Provider[Mistral] = field(repr=False)

     def __init__(
         self,
@@ -147,13 +147,24 @@ class MistralModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
-        return self.
+        return self._provider.base_url
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -167,7 +178,6 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -186,16 +196,6 @@ class MistralModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> MistralModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -348,7 +348,11 @@ class MistralModel(Model):
             parts.append(tool)

         return ModelResponse(
-            parts,
+            parts,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -699,17 +703,15 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Usage(
-
-
-            total_tokens=response.usage.total_tokens,
-            details=None,
+        return RequestUsage(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
         )
     else:
-        return Usage()
+        return RequestUsage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: