pydantic-ai-slim 0.7.2-py3-none-any.whl → 0.7.4-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim has been flagged as possibly problematic.
Files changed (39)
  1. pydantic_ai/__init__.py +2 -1
  2. pydantic_ai/_agent_graph.py +2 -2
  3. pydantic_ai/_cli.py +18 -3
  4. pydantic_ai/_run_context.py +2 -2
  5. pydantic_ai/ag_ui.py +4 -4
  6. pydantic_ai/agent/__init__.py +7 -9
  7. pydantic_ai/agent/abstract.py +16 -18
  8. pydantic_ai/agent/wrapper.py +4 -6
  9. pydantic_ai/builtin_tools.py +9 -1
  10. pydantic_ai/direct.py +4 -4
  11. pydantic_ai/durable_exec/temporal/_agent.py +13 -15
  12. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  13. pydantic_ai/messages.py +16 -6
  14. pydantic_ai/models/__init__.py +5 -5
  15. pydantic_ai/models/anthropic.py +27 -26
  16. pydantic_ai/models/bedrock.py +24 -26
  17. pydantic_ai/models/cohere.py +20 -25
  18. pydantic_ai/models/fallback.py +15 -15
  19. pydantic_ai/models/function.py +7 -9
  20. pydantic_ai/models/gemini.py +43 -39
  21. pydantic_ai/models/google.py +76 -50
  22. pydantic_ai/models/groq.py +22 -19
  23. pydantic_ai/models/huggingface.py +18 -21
  24. pydantic_ai/models/instrumented.py +4 -4
  25. pydantic_ai/models/mcp_sampling.py +1 -2
  26. pydantic_ai/models/mistral.py +24 -22
  27. pydantic_ai/models/openai.py +98 -44
  28. pydantic_ai/models/test.py +4 -5
  29. pydantic_ai/profiles/openai.py +13 -3
  30. pydantic_ai/providers/openai.py +1 -1
  31. pydantic_ai/result.py +5 -5
  32. pydantic_ai/run.py +4 -11
  33. pydantic_ai/tools.py +5 -2
  34. pydantic_ai/usage.py +230 -68
  35. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/METADATA +10 -4
  36. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/RECORD +39 -39
  37. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/WHEEL +0 -0
  38. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/entry_points.txt +0 -0
  39. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py

@@ -13,7 +13,7 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
 from .._run_context import RunContext
-from ..builtin_tools import CodeExecutionTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, UrlContextTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -72,6 +72,7 @@ try:
         ToolConfigDict,
         ToolDict,
         ToolListUnionDict,
+        UrlContextDict,
     )

     from ..providers.google import GoogleProvider
@@ -144,7 +145,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)

     def __init__(
         self,
@@ -168,9 +168,7 @@ class GoogleModel(Model):

         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -179,6 +177,16 @@ class GoogleModel(Model):
     def base_url(self) -> str:
         return self._provider.base_url

+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
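
The `model_name` and `system` properties move up next to `base_url`, and `system` is now derived from the provider rather than the removed `_system` field. A minimal sketch of the visible effect, assuming the default `google-gla` provider and an API key configured in the environment (model id illustrative):

```python
from pydantic_ai.models.google import GoogleModel

model = GoogleModel('gemini-2.5-flash')  # string providers map to GoogleProvider
print(model.model_name)  # 'gemini-2.5-flash'
print(model.system)      # 'google-gla', read from the provider, not a cached field
```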
@@ -195,7 +203,7 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> usage.Usage:
+    ) -> usage.RequestUsage:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
@@ -209,9 +217,9 @@ class GoogleModel(Model):
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.system != 'google-gla':
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
-            config.update(
+            config.update(  # pragma: lax no cover
                 system_instruction=generation_config.get('system_instruction'),
                 tools=cast(list[ToolDict], generation_config.get('tools')),
                 # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
@@ -238,9 +246,8 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior(  # pragma: no cover
                 'Total tokens missing from Gemini response', str(response)
             )
-        return usage.Usage(
-            request_tokens=response.total_tokens,
-            total_tokens=response.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response.total_tokens,
         )

     @asynccontextmanager
@@ -256,16 +263,6 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
@@ -274,6 +271,8 @@ class GoogleModel(Model):
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 tools.append(ToolDict(google_search=GoogleSearchDict()))
+            elif isinstance(tool, UrlContextTool):
+                tools.append(ToolDict(url_context=UrlContextDict()))
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
             else:  # pragma: no cover
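
The new `UrlContextTool` is wired into `_get_tools` exactly like the other builtin tools. A hedged sketch of enabling it from user code, assuming it is passed through the same `builtin_tools` parameter that `WebSearchTool` uses (model id and prompt are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import UrlContextTool

agent = Agent('google-gla:gemini-2.5-flash', builtin_tools=[UrlContextTool()])
result = agent.run_sync('Summarize https://ai.pydantic.dev in one sentence.')
print(result.output)
```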
@@ -378,23 +377,27 @@ class GoogleModel(Model):
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
-        if response.candidates[0].content is None or response.candidates[0].content.parts is None:
-            if response.candidates[0].finish_reason == 'SAFETY':
+        candidate = response.candidates[0]
+        if candidate.content is None or candidate.content.parts is None:
+            if candidate.finish_reason == 'SAFETY':
                 raise UnexpectedModelBehavior('Safety settings triggered', str(response))
             else:
                 raise UnexpectedModelBehavior(
                     'Content field missing from Gemini response', str(response)
                 )  # pragma: no cover
-        parts = response.candidates[0].content.parts or []
+        parts = candidate.content.parts or []
         vendor_id = response.response_id or None
         vendor_details: dict[str, Any] | None = None
-        finish_reason = response.candidates[0].finish_reason
+        finish_reason = candidate.finish_reason
         if finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
-            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
         )

     async def _process_streamed_response(
@@ -527,10 +530,13 @@ class GeminiStreamedResponse(StreamedResponse):

             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
-            if candidate.content is None:
-                raise UnexpectedModelBehavior('Streamed response has no content field')  # pragma: no cover
-            assert candidate.content.parts is not None
-            for part in candidate.content.parts:
+            if candidate.content is None or candidate.content.parts is None:
+                if candidate.finish_reason == 'SAFETY':  # pragma: no cover
+                    raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
+                else:  # pragma: no cover
+                    raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+            parts = candidate.content.parts or []
+            for part in parts:
                 if part.text is not None:
                     if part.thought:
                         yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
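
With this change the streaming path raises the same `UnexpectedModelBehavior` for safety blocks as the non-streaming path. A sketch of catching it from user code; the streaming API shown (`run_stream` / `stream_text`) is the existing public one, not something introduced in this diff:

```python
import asyncio

from pydantic_ai import Agent, UnexpectedModelBehavior

agent = Agent('google-gla:gemini-2.5-flash')  # illustrative model id

async def main() -> None:
    try:
        async with agent.run_stream('Tell me a story') as stream:
            async for text in stream.stream_text():
                print(text)
    except UnexpectedModelBehavior as exc:
        # e.g. 'Safety settings triggered' raised mid-stream
        print('stream aborted:', exc)

asyncio.run(main())
```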
@@ -590,7 +596,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -627,7 +633,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -647,31 +653,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict:
     return ToolConfigDict(function_calling_config=function_calling_config)


-def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     metadata = response.usage_metadata
     if metadata is None:
-        return usage.Usage()  # pragma: no cover
-    metadata = metadata.model_dump(exclude_defaults=True)
-
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+    if cached_content_token_count := metadata.cached_content_token_count:
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.thoughts_token_count:
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
         details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

-    for key, metadata_details in metadata.items():
-        if key.endswith('_details') and metadata_details:
-            suffix = key.removesuffix('_details')
-            for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
-
-    return usage.Usage(
-        request_tokens=metadata.get('prompt_token_count', 0),
-        response_tokens=metadata.get('candidates_token_count', 0),
-        total_tokens=metadata.get('total_token_count', 0),
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
+    for prefix, metadata_details in [
+        ('prompt', metadata.prompt_tokens_details),
+        ('cache', metadata.cache_tokens_details),
+        ('candidates', metadata.candidates_tokens_details),
+        ('tool_use_prompt', metadata.tool_use_prompt_tokens_details),
+    ]:
+        assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details
+        if not metadata_details:
+            continue
+        for detail in metadata_details:
+            if not detail.modality or not detail.token_count:  # pragma: no cover
+                continue
+            details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
+            if detail.modality != 'AUDIO':
+                continue
+            if metadata_details is metadata.prompt_tokens_details:
+                input_audio_tokens = detail.token_count
+            elif metadata_details is metadata.candidates_tokens_details:
+                output_audio_tokens = detail.token_count
+            elif metadata_details is metadata.cache_tokens_details:  # pragma: no branch
+                cache_audio_read_tokens = detail.token_count
+
+    return usage.RequestUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        cache_read_tokens=cached_content_token_count or 0,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
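
The rewritten `_metadata_as_usage` reads the typed `usage_metadata` object directly (no more `model_dump()` round-trip) and fills in the new `RequestUsage` fields, including per-modality audio counts. A sketch of the resulting shape; the field names come from this diff, the numbers are made up:

```python
from pydantic_ai.usage import RequestUsage

u = RequestUsage(
    input_tokens=1200,    # was request_tokens on the old Usage
    output_tokens=350,    # was response_tokens
    cache_read_tokens=800,
    input_audio_tokens=0,
    output_audio_tokens=0,
    cache_audio_read_tokens=0,
    details={'thoughts_tokens': 120},
)
print(u.input_tokens, u.output_tokens)
```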
pydantic_ai/models/groq.py

@@ -118,7 +118,7 @@ class GroqModel(Model):
     client: AsyncGroq = field(repr=False)

     _model_name: GroqModelName = field(repr=False)
-    _system: str = field(default='groq', repr=False)
+    _provider: Provider[AsyncGroq] = field(repr=False)

     def __init__(
         self,
@@ -143,6 +143,7 @@ class GroqModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -151,6 +152,16 @@ class GroqModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -162,7 +173,6 @@ class GroqModel(Model):
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -180,16 +190,6 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> GroqModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -285,7 +285,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -484,7 +488,7 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
     response_usage = None
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
@@ -492,10 +496,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         response_usage = completion.x_groq.usage

     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()

-    return usage.Usage(
-        request_tokens=response_usage.prompt_tokens,
-        response_tokens=response_usage.completion_tokens,
-        total_tokens=response_usage.total_tokens,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
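
Note that the rewritten mappers no longer set `total_tokens` (or `requests`) at all. The assumption here is that `RequestUsage` derives the total instead of storing it; a sketch of that assumption, not something shown in this diff:

```python
from pydantic_ai.usage import RequestUsage

u = RequestUsage(input_tokens=10, output_tokens=5)
print(u.total_tokens)  # expected 15 if the total is computed from input + output
```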
pydantic_ai/models/huggingface.py

@@ -114,7 +114,7 @@ class HuggingFaceModel(Model):
     client: AsyncInferenceClient = field(repr=False)

     _model_name: str = field(repr=False)
-    _system: str = field(default='huggingface', repr=False)
+    _provider: Provider[AsyncInferenceClient] = field(repr=False)

     def __init__(
         self,
@@ -134,13 +134,23 @@ class HuggingFaceModel(Model):
             settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
-        self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -152,7 +162,6 @@ class HuggingFaceModel(Model):
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -169,16 +178,6 @@ class HuggingFaceModel(Model):
         )
         yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> HuggingFaceModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -272,7 +271,7 @@ class HuggingFaceModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            vendor_id=response.id,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -481,14 +480,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()

-    return usage.Usage(
-        request_tokens=response_usage.prompt_tokens,
-        response_tokens=response_usage.completion_tokens,
-        total_tokens=response_usage.total_tokens,
-        details=None,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
pydantic_ai/models/instrumented.py

@@ -280,14 +280,14 @@ class InstrumentedModel(WrapperModel):
             'gen_ai.request.model': request_model,
             'gen_ai.response.model': response_model,
         }
-        if response.usage.request_tokens:  # pragma: no branch
+        if response.usage.input_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.request_tokens,
+                response.usage.input_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'input'},
             )
-        if response.usage.response_tokens:  # pragma: no branch
+        if response.usage.output_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.response_tokens,
+                response.usage.output_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'output'},
             )
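
The renamed fields feed the same tokens histogram as before, keyed by `gen_ai.token.type`. A self-contained sketch of the equivalent manual recording; the histogram name `gen_ai.client.token.usage` is an assumption, only the attribute keys appear in this diff:

```python
from opentelemetry.metrics import get_meter

tokens_histogram = get_meter('example').create_histogram('gen_ai.client.token.usage')

def record_usage(input_tokens: int, output_tokens: int) -> None:
    # Mirrors InstrumentedModel: one record per token type, skipping zero counts.
    if input_tokens:
        tokens_histogram.record(input_tokens, {'gen_ai.token.type': 'input'})
    if output_tokens:
        tokens_histogram.record(output_tokens, {'gen_ai.token.type': 'output'})

record_usage(1200, 350)
```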
 
pydantic_ai/models/mcp_sampling.py

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast

-from .. import _mcp, exceptions, usage
+from .. import _mcp, exceptions
 from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
@@ -63,7 +63,6 @@ class MCPSamplingModel(Model):
         if result.role == 'assistant':
             return ModelResponse(
                 parts=[_mcp.map_from_sampling_content(result.content)],
-                usage=usage.Usage(requests=1),
                 model_name=result.model,
             )
         else:
pydantic_ai/models/mistral.py

@@ -40,7 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 from . import (
     Model,
     ModelRequestParameters,
@@ -120,7 +120,7 @@ class MistralModel(Model):
     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

     _model_name: MistralModelName = field(repr=False)
-    _system: str = field(default='mistral_ai', repr=False)
+    _provider: Provider[Mistral] = field(repr=False)

     def __init__(
         self,
@@ -147,13 +147,24 @@ class MistralModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
-        return self.client.sdk_configuration.get_server_details()[0]
+        return self._provider.base_url
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -167,7 +178,6 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -186,16 +196,6 @@ class MistralModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> MistralModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -348,7 +348,11 @@ class MistralModel(Model):
             parts.append(tool)

         return ModelResponse(
-            parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+            parts,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -699,17 +703,15 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Usage(
-            request_tokens=response.usage.prompt_tokens,
-            response_tokens=response.usage.completion_tokens,
-            total_tokens=response.usage.total_tokens,
-            details=None,
+        return RequestUsage(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
         )
     else:
-        return Usage()  # pragma: no cover
+        return RequestUsage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: