pydantic-ai-slim 0.7.2__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in the public registry they were published to. It is provided for informational purposes only.
- pydantic_ai/_agent_graph.py +2 -2
- pydantic_ai/_cli.py +18 -3
- pydantic_ai/_run_context.py +2 -2
- pydantic_ai/ag_ui.py +4 -4
- pydantic_ai/agent/__init__.py +7 -9
- pydantic_ai/agent/abstract.py +16 -18
- pydantic_ai/agent/wrapper.py +4 -6
- pydantic_ai/direct.py +4 -4
- pydantic_ai/durable_exec/temporal/_agent.py +13 -15
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/messages.py +16 -6
- pydantic_ai/models/__init__.py +5 -5
- pydantic_ai/models/anthropic.py +27 -26
- pydantic_ai/models/bedrock.py +24 -26
- pydantic_ai/models/cohere.py +20 -25
- pydantic_ai/models/fallback.py +15 -15
- pydantic_ai/models/function.py +7 -9
- pydantic_ai/models/gemini.py +43 -39
- pydantic_ai/models/google.py +59 -40
- pydantic_ai/models/groq.py +22 -19
- pydantic_ai/models/huggingface.py +18 -21
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mcp_sampling.py +1 -2
- pydantic_ai/models/mistral.py +24 -22
- pydantic_ai/models/openai.py +98 -44
- pydantic_ai/models/test.py +4 -5
- pydantic_ai/profiles/openai.py +13 -3
- pydantic_ai/providers/openai.py +1 -1
- pydantic_ai/result.py +5 -5
- pydantic_ai/run.py +4 -11
- pydantic_ai/usage.py +229 -67
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +36 -36
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
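
The thread running through the per-file hunks below is the reworked usage API (pydantic_ai/usage.py, +229 -67): each model adapter now returns usage.RequestUsage instead of the old usage type, and the per-adapter `requests = 1` bookkeeping lines are removed, with request counting presumably handled at the run level in usage.py. As a minimal sketch of the new shape (the field names are taken from the hunks in this diff; the numbers and the wrapper function are purely illustrative):

    # Sketch only: assumes pydantic-ai-slim 0.7.3; field names as they appear in this diff.
    from pydantic_ai.usage import RequestUsage

    def example_request_usage() -> RequestUsage:
        # One RequestUsage is built per provider response by the model adapters.
        return RequestUsage(
            input_tokens=1200,                  # prompt-side tokens
            output_tokens=345,                  # completion-side tokens
            cache_read_tokens=800,              # tokens served from a provider-side cache
            details={'thoughts_tokens': 120},   # provider-specific extras
        )
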
pydantic_ai/models/google.py
CHANGED

@@ -144,7 +144,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)

     def __init__(
         self,
@@ -168,9 +167,7 @@ class GoogleModel(Model):

         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -179,6 +176,16 @@ class GoogleModel(Model):
     def base_url(self) -> str:
         return self._provider.base_url

+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -195,7 +202,7 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> usage.
+    ) -> usage.RequestUsage:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
@@ -209,7 +216,7 @@ class GoogleModel(Model):
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
             config.update(
                 system_instruction=generation_config.get('system_instruction'),
@@ -238,9 +245,8 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior(  # pragma: no cover
                 'Total tokens missing from Gemini response', str(response)
             )
-        return usage.
-
-            total_tokens=response.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response.total_tokens,
         )

     @asynccontextmanager
@@ -256,16 +262,6 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
@@ -392,9 +388,12 @@ class GoogleModel(Model):
         if finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
-            parts,
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
         )

     async def _process_streamed_response(
@@ -590,7 +589,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
-    usage: usage.
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -627,7 +626,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage,
+        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -647,31 +646,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict:
     return ToolConfigDict(function_calling_config=function_calling_config)


-def _metadata_as_usage(response: GenerateContentResponse) -> usage.
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     metadata = response.usage_metadata
     if metadata is None:
-        return usage.
-    metadata = metadata.model_dump(exclude_defaults=True)
-
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.
-        details['cached_content_tokens'] = cached_content_token_count
+    if cached_content_token_count := metadata.cached_content_token_count:
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.
+    if thoughts_token_count := metadata.thoughts_token_count:
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.
+    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
         details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

-
-
-
-
-
-
-
-
-
-
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
+    for prefix, metadata_details in [
+        ('prompt', metadata.prompt_tokens_details),
+        ('cache', metadata.cache_tokens_details),
+        ('candidates', metadata.candidates_tokens_details),
+        ('tool_use_prompt', metadata.tool_use_prompt_tokens_details),
+    ]:
+        assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details
+        if not metadata_details:
+            continue
+        for detail in metadata_details:
+            if not detail.modality or not detail.token_count:  # pragma: no cover
+                continue
+            details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
+            if detail.modality != 'AUDIO':
+                continue
+            if metadata_details is metadata.prompt_tokens_details:
+                input_audio_tokens = detail.token_count
+            elif metadata_details is metadata.candidates_tokens_details:
+                output_audio_tokens = detail.token_count
+            elif metadata_details is metadata.cache_tokens_details:  # pragma: no branch
+                cache_audio_read_tokens = detail.token_count
+
+    return usage.RequestUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        cache_read_tokens=cached_content_token_count or 0,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
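
The new _metadata_as_usage above also folds Gemini's per-modality token details into the details dict and promotes audio counts to dedicated RequestUsage fields. The following standalone sketch mirrors that folding logic with plain tuples in place of the google-genai metadata objects; the prefixes and the details key format come from the hunk, the sample numbers do not:

    # Stand-ins for metadata.<prefix>_tokens_details: lists of (modality, token_count).
    token_details = {
        'prompt': [('TEXT', 1000), ('AUDIO', 200)],
        'cache': [],
        'candidates': [('TEXT', 300)],
        'tool_use_prompt': [('TEXT', 50)],
    }

    details: dict[str, int] = {}
    input_audio_tokens = output_audio_tokens = cache_audio_read_tokens = 0

    for prefix, modality_details in token_details.items():
        for modality, token_count in modality_details:
            if not modality or not token_count:
                continue
            # e.g. 'text_prompt_tokens', 'audio_prompt_tokens', 'text_candidates_tokens'
            details[f'{modality.lower()}_{prefix}_tokens'] = token_count
            if modality != 'AUDIO':
                continue
            if prefix == 'prompt':
                input_audio_tokens = token_count
            elif prefix == 'candidates':
                output_audio_tokens = token_count
            elif prefix == 'cache':
                cache_audio_read_tokens = token_count

    assert details['audio_prompt_tokens'] == 200
    assert (input_audio_tokens, output_audio_tokens, cache_audio_read_tokens) == (200, 0, 0)
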
pydantic_ai/models/groq.py
CHANGED

@@ -118,7 +118,7 @@ class GroqModel(Model):
     client: AsyncGroq = field(repr=False)

     _model_name: GroqModelName = field(repr=False)
-
+    _provider: Provider[AsyncGroq] = field(repr=False)

     def __init__(
         self,
@@ -143,6 +143,7 @@ class GroqModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -151,6 +152,16 @@ class GroqModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -162,7 +173,6 @@ class GroqModel(Model):
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -180,16 +190,6 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> GroqModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -285,7 +285,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -484,7 +488,7 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
     response_usage = None
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
@@ -492,10 +496,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         response_usage = completion.x_groq.usage

     if response_usage is None:
-        return usage.
+        return usage.RequestUsage()

-    return usage.
-
-
-        total_tokens=response_usage.total_tokens,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
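
Another pattern repeated across the adapters (Groq above, Hugging Face and Mistral further down) is that _process_response now hands the provider's own response id to ModelResponse as provider_request_id, alongside a RequestUsage. A minimal hedged sketch of constructing such a response by hand; the keyword names come from this diff, the literal values are invented:

    from datetime import datetime, timezone

    from pydantic_ai.messages import ModelResponse, TextPart
    from pydantic_ai.usage import RequestUsage

    # Illustrative only: a response built the way the 0.7.3 adapters build it.
    response = ModelResponse(
        parts=[TextPart(content='hello')],
        usage=RequestUsage(input_tokens=12, output_tokens=3),
        model_name='example-model',            # placeholder model name
        timestamp=datetime.now(timezone.utc),
        provider_request_id='req_123',         # the provider's response id, e.g. response.id
    )
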
pydantic_ai/models/huggingface.py
CHANGED

@@ -114,7 +114,7 @@ class HuggingFaceModel(Model):
     client: AsyncInferenceClient = field(repr=False)

     _model_name: str = field(repr=False)
-
+    _provider: Provider[AsyncInferenceClient] = field(repr=False)

     def __init__(
         self,
@@ -134,13 +134,23 @@ class HuggingFaceModel(Model):
             settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
-        self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -152,7 +162,6 @@ class HuggingFaceModel(Model):
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -169,16 +178,6 @@ class HuggingFaceModel(Model):
         )
         yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> HuggingFaceModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -272,7 +271,7 @@ class HuggingFaceModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -481,14 +480,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.
+        return usage.RequestUsage()

-    return usage.
-
-
-        total_tokens=response_usage.total_tokens,
-        details=None,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
pydantic_ai/models/instrumented.py
CHANGED

@@ -280,14 +280,14 @@ class InstrumentedModel(WrapperModel):
             'gen_ai.request.model': request_model,
             'gen_ai.response.model': response_model,
         }
-        if response.usage.
+        if response.usage.input_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.
+                response.usage.input_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'input'},
             )
-        if response.usage.
+        if response.usage.output_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.
+                response.usage.output_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'output'},
             )

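
The InstrumentedModel hunk above switches the metrics code to the new input_tokens/output_tokens attributes, recording each non-zero count against a gen_ai.token.type attribute. A rough sketch of that pattern with a stand-in recorder; the attribute keys mirror the hunk, the recorder callable is hypothetical:

    from typing import Any, Callable

    def record_token_metrics(
        record: Callable[[int, dict[str, Any]], None],  # stand-in for tokens_histogram.record
        input_tokens: int,
        output_tokens: int,
        metric_attributes: dict[str, Any],
    ) -> None:
        # Mirrors the hunk: only non-zero counts are recorded, tagged by token type.
        if input_tokens:
            record(input_tokens, {**metric_attributes, 'gen_ai.token.type': 'input'})
        if output_tokens:
            record(output_tokens, {**metric_attributes, 'gen_ai.token.type': 'output'})

    # Example: collect the recorded data points in a list.
    points: list[tuple[int, dict[str, Any]]] = []
    record_token_metrics(lambda n, attrs: points.append((n, attrs)), 12, 3, {'gen_ai.request.model': 'example'})
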
pydantic_ai/models/mcp_sampling.py
CHANGED

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast

-from .. import _mcp, exceptions
+from .. import _mcp, exceptions
 from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
@@ -63,7 +63,6 @@ class MCPSamplingModel(Model):
         if result.role == 'assistant':
             return ModelResponse(
                 parts=[_mcp.map_from_sampling_content(result.content)],
-                usage=usage.Usage(requests=1),
                 model_name=result.model,
             )
         else:
pydantic_ai/models/mistral.py
CHANGED

@@ -40,7 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import
+from ..usage import RequestUsage
 from . import (
     Model,
     ModelRequestParameters,
@@ -120,7 +120,7 @@ class MistralModel(Model):
     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

     _model_name: MistralModelName = field(repr=False)
-
+    _provider: Provider[Mistral] = field(repr=False)

     def __init__(
         self,
@@ -147,13 +147,24 @@ class MistralModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
-        return self.
+        return self._provider.base_url
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -167,7 +178,6 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -186,16 +196,6 @@ class MistralModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> MistralModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -348,7 +348,11 @@ class MistralModel(Model):
                 parts.append(tool)

         return ModelResponse(
-            parts,
+            parts,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -699,17 +703,15 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) ->
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return
-
-
-            total_tokens=response.usage.total_tokens,
-            details=None,
+        return RequestUsage(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
         )
     else:
-        return
+        return RequestUsage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: