pydantic_ai_slim-0.7.1-py3-none-any.whl → pydantic_ai_slim-0.7.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. pydantic_ai/_agent_graph.py +60 -57
  2. pydantic_ai/_cli.py +18 -3
  3. pydantic_ai/_parts_manager.py +5 -4
  4. pydantic_ai/_run_context.py +2 -2
  5. pydantic_ai/_tool_manager.py +50 -29
  6. pydantic_ai/ag_ui.py +4 -4
  7. pydantic_ai/agent/__init__.py +69 -84
  8. pydantic_ai/agent/abstract.py +16 -18
  9. pydantic_ai/agent/wrapper.py +4 -6
  10. pydantic_ai/direct.py +4 -4
  11. pydantic_ai/durable_exec/temporal/_agent.py +13 -15
  12. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  13. pydantic_ai/messages.py +16 -6
  14. pydantic_ai/models/__init__.py +5 -5
  15. pydantic_ai/models/anthropic.py +47 -46
  16. pydantic_ai/models/bedrock.py +25 -27
  17. pydantic_ai/models/cohere.py +20 -25
  18. pydantic_ai/models/fallback.py +15 -15
  19. pydantic_ai/models/function.py +7 -9
  20. pydantic_ai/models/gemini.py +43 -39
  21. pydantic_ai/models/google.py +59 -40
  22. pydantic_ai/models/groq.py +23 -19
  23. pydantic_ai/models/huggingface.py +27 -23
  24. pydantic_ai/models/instrumented.py +4 -4
  25. pydantic_ai/models/mcp_sampling.py +1 -2
  26. pydantic_ai/models/mistral.py +24 -22
  27. pydantic_ai/models/openai.py +101 -45
  28. pydantic_ai/models/test.py +4 -5
  29. pydantic_ai/profiles/__init__.py +10 -1
  30. pydantic_ai/profiles/deepseek.py +1 -1
  31. pydantic_ai/profiles/moonshotai.py +1 -1
  32. pydantic_ai/profiles/openai.py +13 -3
  33. pydantic_ai/profiles/qwen.py +4 -1
  34. pydantic_ai/providers/__init__.py +4 -0
  35. pydantic_ai/providers/huggingface.py +27 -0
  36. pydantic_ai/providers/ollama.py +105 -0
  37. pydantic_ai/providers/openai.py +1 -1
  38. pydantic_ai/providers/openrouter.py +2 -0
  39. pydantic_ai/result.py +6 -6
  40. pydantic_ai/run.py +4 -11
  41. pydantic_ai/tools.py +9 -9
  42. pydantic_ai/usage.py +229 -67
  43. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
  44. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +47 -46
  45. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
  46. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
  47. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
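
Most of the hunks below apply the same two renames across the model classes: `usage.Usage(request_tokens=..., response_tokens=..., total_tokens=...)` becomes `usage.RequestUsage(input_tokens=..., output_tokens=...)`, and `ModelResponse` now takes `provider_request_id`/`provider_details` in place of `vendor_id`/`vendor_details`. A minimal sketch of the new names on a hand-built response, assuming only what the diff itself shows (the values are illustrative):

from pydantic_ai import usage
from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(
    parts=[TextPart(content='hello')],
    usage=usage.RequestUsage(input_tokens=12, output_tokens=3),  # was Usage(request_tokens=..., response_tokens=...)
    model_name='example-model',  # illustrative value
    provider_request_id='req-123',  # was vendor_id in 0.7.1
    provider_details={'finish_reason': 'STOP'},  # was vendor_details in 0.7.1
)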
pydantic_ai/models/gemini.py

@@ -40,14 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 LatestGeminiModelNames = Literal[
     'gemini-2.0-flash',
@@ -108,10 +101,9 @@ class GeminiModel(Model):
     client: httpx.AsyncClient = field(repr=False)
 
     _model_name: GeminiModelName = field(repr=False)
-    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
+    _provider: Provider[httpx.AsyncClient] = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='gemini', repr=False)
 
     def __init__(
         self,
@@ -132,11 +124,10 @@ class GeminiModel(Model):
             settings: Default model settings for this model instance.
         """
         self._model_name = model_name
-        self._provider = provider
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
-        self._system = provider.name
+        self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)
 
@@ -147,6 +138,16 @@ class GeminiModel(Model):
         assert self._url is not None, 'URL not initialized'  # pragma: no cover
         return self._url  # pragma: no cover
 
+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -175,16 +176,6 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response, model_request_parameters)
 
-    @property
-    def model_name(self) -> GeminiModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
@@ -237,7 +228,7 @@ class GeminiModel(Model):
             request_data['safetySettings'] = gemini_safety_settings
 
         if gemini_labels := model_settings.get('gemini_labels'):
-            if self._system == 'google-vertex':
+            if self._provider.name == 'google-vertex':
                 request_data['labels'] = gemini_labels  # pragma: lax no cover
 
         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
@@ -278,7 +269,6 @@ class GeminiModel(Model):
         if finish_reason:
             vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
             parts,
             response.get('model_version', self._model_name),
@@ -673,7 +663,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
 def _process_response_from_parts(
     parts: Sequence[_GeminiPartUnion],
     model_name: GeminiModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -693,7 +683,7 @@ def _process_response_from_parts(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(
-        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
+        parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details
     )
 
 
@@ -859,31 +849,45 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     ]
 
 
-def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return usage.Usage()  # pragma: no cover
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+    if cached_content_token_count := metadata.get('cached_content_token_count', 0):
+        details['cached_content_tokens'] = cached_content_token_count
 
-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.get('thoughts_token_count', 0):
         details['thoughts_tokens'] = thoughts_token_count
 
-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
-        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0):
        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count
 
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
     for key, metadata_details in metadata.items():
         if key.endswith('_details') and metadata_details:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
-
-    return usage.Usage(
-        request_tokens=metadata.get('prompt_token_count', 0),
-        response_tokens=metadata.get('candidates_token_count', 0),
-        total_tokens=metadata.get('total_token_count', 0),
+                modality = detail['modality']
+                details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0)
+                if value and modality == 'AUDIO':
+                    if key == 'prompt_tokens_details':
+                        input_audio_tokens = value
+                    elif key == 'candidates_tokens_details':
+                        output_audio_tokens = value
+                    elif key == 'cache_tokens_details':  # pragma: no branch
+                        cache_audio_read_tokens = value
+
+    return usage.RequestUsage(
+        input_tokens=metadata.get('prompt_token_count', 0),
+        output_tokens=metadata.get('candidates_token_count', 0),
+        cache_read_tokens=cached_content_token_count,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
 
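For reference, a worked example of what the rewritten `_metadata_as_usage` above would produce for a hypothetical `usage_metadata` payload (the field names come from the hunk; the numbers are made up):

metadata = {
    'prompt_token_count': 100,
    'candidates_token_count': 40,
    'cached_content_token_count': 25,
    'thoughts_token_count': 10,
    'prompt_tokens_details': [
        {'modality': 'TEXT', 'token_count': 80},
        {'modality': 'AUDIO', 'token_count': 20},
    ],
}

# Following the branches above, this maps to:
# usage.RequestUsage(
#     input_tokens=100,
#     output_tokens=40,
#     cache_read_tokens=25,
#     input_audio_tokens=20,  # the AUDIO entry under prompt_tokens_details
#     output_audio_tokens=0,
#     cache_audio_read_tokens=0,
#     details={
#         'cached_content_tokens': 25,
#         'thoughts_tokens': 10,
#         'text_prompt_tokens': 80,
#         'audio_prompt_tokens': 20,
#     },
# )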
pydantic_ai/models/google.py

@@ -144,7 +144,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)
 
     def __init__(
         self,
@@ -168,9 +167,7 @@ class GoogleModel(Model):
 
         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -179,6 +176,16 @@ class GoogleModel(Model):
     def base_url(self) -> str:
         return self._provider.base_url
 
+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -195,7 +202,7 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> usage.Usage:
+    ) -> usage.RequestUsage:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
@@ -209,7 +216,7 @@ class GoogleModel(Model):
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.system != 'google-gla':
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
             config.update(
                 system_instruction=generation_config.get('system_instruction'),
@@ -238,9 +245,8 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior(  # pragma: no cover
                 'Total tokens missing from Gemini response', str(response)
             )
-        return usage.Usage(
-            request_tokens=response.total_tokens,
-            total_tokens=response.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response.total_tokens,
         )
 
     @asynccontextmanager
@@ -256,16 +262,6 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
 
-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
@@ -392,9 +388,12 @@ class GoogleModel(Model):
         if finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
-            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
         )
 
     async def _process_streamed_response(
@@ -590,7 +589,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -627,7 +626,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
     )
 
 
@@ -647,31 +646,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict:
     return ToolConfigDict(function_calling_config=function_calling_config)
 
 
-def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     metadata = response.usage_metadata
     if metadata is None:
-        return usage.Usage()  # pragma: no cover
-    metadata = metadata.model_dump(exclude_defaults=True)
-
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+    if cached_content_token_count := metadata.cached_content_token_count:
+        details['cached_content_tokens'] = cached_content_token_count
 
-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.thoughts_token_count:
         details['thoughts_tokens'] = thoughts_token_count
 
-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
         details['tool_use_prompt_tokens'] = tool_use_prompt_token_count
 
-    for key, metadata_details in metadata.items():
-        if key.endswith('_details') and metadata_details:
-            suffix = key.removesuffix('_details')
-            for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
-
-    return usage.Usage(
-        request_tokens=metadata.get('prompt_token_count', 0),
-        response_tokens=metadata.get('candidates_token_count', 0),
-        total_tokens=metadata.get('total_token_count', 0),
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
+    for prefix, metadata_details in [
+        ('prompt', metadata.prompt_tokens_details),
+        ('cache', metadata.cache_tokens_details),
+        ('candidates', metadata.candidates_tokens_details),
+        ('tool_use_prompt', metadata.tool_use_prompt_tokens_details),
+    ]:
+        assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details
+        if not metadata_details:
+            continue
+        for detail in metadata_details:
+            if not detail.modality or not detail.token_count:  # pragma: no cover
+                continue
+            details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
+            if detail.modality != 'AUDIO':
+                continue
+            if metadata_details is metadata.prompt_tokens_details:
+                input_audio_tokens = detail.token_count
+            elif metadata_details is metadata.candidates_tokens_details:
+                output_audio_tokens = detail.token_count
+            elif metadata_details is metadata.cache_tokens_details:  # pragma: no branch
+                cache_audio_read_tokens = detail.token_count
+
+    return usage.RequestUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        cache_read_tokens=cached_content_token_count or 0,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
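
Across GeminiModel, GoogleModel, GroqModel and HuggingFaceModel, the dedicated `_system` field is dropped: the provider instance is stored instead and `system` is derived from `provider.name`, with the `model_name`/`system` properties moved up next to `base_url`. A compressed sketch of the pattern using hypothetical stand-in classes (not a drop-in copy of any one model):

from dataclasses import dataclass, field


class ExampleProvider:
    """Minimal stand-in for pydantic_ai.providers.Provider; illustrative only."""

    def __init__(self, name: str, client: object = None):
        self.name = name
        self.client = client


@dataclass(init=False)
class ExampleModel:
    """Hypothetical model class mirroring the pattern in the hunks above."""

    _model_name: str = field(repr=False)
    _provider: ExampleProvider = field(repr=False)  # replaces the old `_system: str` field

    def __init__(self, model_name: str, provider: ExampleProvider):
        self._model_name = model_name
        self._provider = provider

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The model provider name, derived from the stored provider."""
        return self._provider.name


print(ExampleModel('gemini-2.0-flash', ExampleProvider('google-gla')).system)  # -> 'google-gla'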
pydantic_ai/models/groq.py

@@ -118,7 +118,7 @@ class GroqModel(Model):
     client: AsyncGroq = field(repr=False)
 
     _model_name: GroqModelName = field(repr=False)
-    _system: str = field(default='groq', repr=False)
+    _provider: Provider[AsyncGroq] = field(repr=False)
 
     def __init__(
         self,
@@ -143,6 +143,7 @@ class GroqModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -151,6 +152,16 @@ class GroqModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)
 
+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -162,7 +173,6 @@ class GroqModel(Model):
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
     @asynccontextmanager
@@ -180,16 +190,6 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)
 
-    @property
-    def model_name(self) -> GroqModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -285,7 +285,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )
 
     async def _process_streamed_response(
@@ -457,6 +461,7 @@ class GroqStreamedResponse(StreamedResponse):
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -483,7 +488,7 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
     response_usage = None
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
@@ -491,10 +496,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         response_usage = completion.x_groq.usage
 
     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()
 
-    return usage.Usage(
-        request_tokens=response_usage.prompt_tokens,
-        response_tokens=response_usage.completion_tokens,
-        total_tokens=response_usage.total_tokens,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
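
Both streamed-response hunks (Groq above, Hugging Face below) now forward a new profile flag into the parts manager, matching the `_parts_manager.py` and `profiles/__init__.py` entries in the file list. A minimal sketch of opting a profile into it, assuming `ModelProfile` exposes the field under the name shown in the diff:

from pydantic_ai.profiles import ModelProfile

# Hypothetical profile opting into the new behaviour; which models need it
# depends on how their providers stream leading whitespace.
profile = ModelProfile(ignore_streamed_leading_whitespace=True)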
pydantic_ai/models/huggingface.py

@@ -35,7 +35,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
-from ..profiles import ModelProfile
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -114,13 +114,15 @@ class HuggingFaceModel(Model):
     client: AsyncInferenceClient = field(repr=False)
 
     _model_name: str = field(repr=False)
-    _system: str = field(default='huggingface', repr=False)
+    _provider: Provider[AsyncInferenceClient] = field(repr=False)
 
     def __init__(
         self,
         model_name: str,
         *,
         provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Hugging Face model.
 
@@ -128,13 +130,27 @@ class HuggingFaceModel(Model):
             model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
             provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
                 instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
-        self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client
 
+        super().__init__(settings=settings, profile=profile or provider.model_profile)
+
+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -146,7 +162,6 @@ class HuggingFaceModel(Model):
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
     @asynccontextmanager
@@ -163,16 +178,6 @@ class HuggingFaceModel(Model):
         )
         yield await self._process_streamed_response(response, model_request_parameters)
 
-    @property
-    def model_name(self) -> HuggingFaceModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -266,7 +271,7 @@ class HuggingFaceModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            vendor_id=response.id,
+            provider_request_id=response.id,
         )
 
     async def _process_streamed_response(
@@ -444,11 +449,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
             content = choice.delta.content
-            if content:
+            if content is not None:
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -474,14 +480,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()
 
-    return usage.Usage(
-        request_tokens=response_usage.prompt_tokens,
-        response_tokens=response_usage.completion_tokens,
-        total_tokens=response_usage.total_tokens,
-        details=None,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
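
With the constructor changes above, `HuggingFaceModel` now accepts `profile` and `settings` like the other model classes. A minimal usage sketch (the model id and settings values are illustrative; running it still needs the `huggingface` extra and provider credentials):

from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.settings import ModelSettings

model = HuggingFaceModel(
    'Qwen/Qwen2.5-72B-Instruct',  # any text-generation model id; illustrative
    provider='huggingface',
    settings=ModelSettings(temperature=0.0),  # used as per-model defaults as of 0.7.3
)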
pydantic_ai/models/instrumented.py

@@ -280,14 +280,14 @@ class InstrumentedModel(WrapperModel):
             'gen_ai.request.model': request_model,
             'gen_ai.response.model': response_model,
         }
-        if response.usage.request_tokens:  # pragma: no branch
+        if response.usage.input_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.request_tokens,
+                response.usage.input_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'input'},
             )
-        if response.usage.response_tokens:  # pragma: no branch
+        if response.usage.output_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.response_tokens,
+                response.usage.output_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'output'},
             )
 
pydantic_ai/models/mcp_sampling.py

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast
 
-from .. import _mcp, exceptions, usage
+from .. import _mcp, exceptions
 from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
@@ -63,7 +63,6 @@ class MCPSamplingModel(Model):
         if result.role == 'assistant':
             return ModelResponse(
                 parts=[_mcp.map_from_sampling_content(result.content)],
-                usage=usage.Usage(requests=1),
                 model_name=result.model,
             )
         else:
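
The repeated removal of `usage.requests = 1` and `usage=usage.Usage(requests=1)` in these hunks suggests request counting now lives on the run-level usage rather than on each response (see the large `pydantic_ai/usage.py` change in the file list). A hedged sketch of reading it, assuming the run result still exposes `usage()` and the run-level object carries `requests`, `input_tokens` and `output_tokens`:

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())
result = agent.run_sync('hello')
run_usage = result.usage()
# Request counts are tracked per run; per-request usage carries token counts only.
print(run_usage.requests, run_usage.input_tokens, run_usage.output_tokens)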