pydantic-ai-slim 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- pydantic_ai/_agent_graph.py +60 -57
- pydantic_ai/_cli.py +18 -3
- pydantic_ai/_parts_manager.py +5 -4
- pydantic_ai/_run_context.py +2 -2
- pydantic_ai/_tool_manager.py +50 -29
- pydantic_ai/ag_ui.py +4 -4
- pydantic_ai/agent/__init__.py +69 -84
- pydantic_ai/agent/abstract.py +16 -18
- pydantic_ai/agent/wrapper.py +4 -6
- pydantic_ai/direct.py +4 -4
- pydantic_ai/durable_exec/temporal/_agent.py +13 -15
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/messages.py +16 -6
- pydantic_ai/models/__init__.py +5 -5
- pydantic_ai/models/anthropic.py +47 -46
- pydantic_ai/models/bedrock.py +25 -27
- pydantic_ai/models/cohere.py +20 -25
- pydantic_ai/models/fallback.py +15 -15
- pydantic_ai/models/function.py +7 -9
- pydantic_ai/models/gemini.py +43 -39
- pydantic_ai/models/google.py +59 -40
- pydantic_ai/models/groq.py +23 -19
- pydantic_ai/models/huggingface.py +27 -23
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mcp_sampling.py +1 -2
- pydantic_ai/models/mistral.py +24 -22
- pydantic_ai/models/openai.py +101 -45
- pydantic_ai/models/test.py +4 -5
- pydantic_ai/profiles/__init__.py +10 -1
- pydantic_ai/profiles/deepseek.py +1 -1
- pydantic_ai/profiles/moonshotai.py +1 -1
- pydantic_ai/profiles/openai.py +13 -3
- pydantic_ai/profiles/qwen.py +4 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/huggingface.py +27 -0
- pydantic_ai/providers/ollama.py +105 -0
- pydantic_ai/providers/openai.py +1 -1
- pydantic_ai/providers/openrouter.py +2 -0
- pydantic_ai/result.py +6 -6
- pydantic_ai/run.py +4 -11
- pydantic_ai/tools.py +9 -9
- pydantic_ai/usage.py +229 -67
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +47 -46
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/gemini.py
CHANGED
@@ -40,14 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 LatestGeminiModelNames = Literal[
     'gemini-2.0-flash',
@@ -108,10 +101,9 @@ class GeminiModel(Model):
     client: httpx.AsyncClient = field(repr=False)

     _model_name: GeminiModelName = field(repr=False)
-    _provider: …
+    _provider: Provider[httpx.AsyncClient] = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='gemini', repr=False)

     def __init__(
         self,
@@ -132,11 +124,10 @@ class GeminiModel(Model):
             settings: Default model settings for this model instance.
         """
         self._model_name = model_name
-        self._provider = provider

         if isinstance(provider, str):
             provider = infer_provider(provider)
-        self.…
+        self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)

@@ -147,6 +138,16 @@ class GeminiModel(Model):
         assert self._url is not None, 'URL not initialized'  # pragma: no cover
         return self._url  # pragma: no cover

+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -175,16 +176,6 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response, model_request_parameters)

-    @property
-    def model_name(self) -> GeminiModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
@@ -237,7 +228,7 @@ class GeminiModel(Model):
             request_data['safetySettings'] = gemini_safety_settings

         if gemini_labels := model_settings.get('gemini_labels'):
-            if self.…
+            if self._provider.name == 'google-vertex':
                 request_data['labels'] = gemini_labels  # pragma: lax no cover

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
@@ -278,7 +269,6 @@ class GeminiModel(Model):
         if finish_reason:
             vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
             parts,
             response.get('model_version', self._model_name),
@@ -673,7 +663,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
 def _process_response_from_parts(
     parts: Sequence[_GeminiPartUnion],
     model_name: GeminiModelName,
-    usage: usage.…
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -693,7 +683,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, usage=usage, model_name=model_name, …
+        parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -859,31 +849,45 @@ class _GeminiUsageMetaData(TypedDict, total=False):
 ]


-def _metadata_as_usage(response: _GeminiResponse) -> usage.…
+def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return usage.…
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count
+    if cached_content_token_count := metadata.get('cached_content_token_count', 0):
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.get('thoughts_token_count', 0):
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
-        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
     for key, metadata_details in metadata.items():
         if key.endswith('_details') and metadata_details:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                …
+                modality = detail['modality']
+                details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0)
+                if value and modality == 'AUDIO':
+                    if key == 'prompt_tokens_details':
+                        input_audio_tokens = value
+                    elif key == 'candidates_tokens_details':
+                        output_audio_tokens = value
+                    elif key == 'cache_tokens_details':  # pragma: no branch
+                        cache_audio_read_tokens = value
+
+    return usage.RequestUsage(
+        input_tokens=metadata.get('prompt_token_count', 0),
+        output_tokens=metadata.get('candidates_token_count', 0),
+        cache_read_tokens=cached_content_token_count,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
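In this gemini.py diff the `model_name` and `system` properties move up next to the constructor, and `system` now returns `self._provider.name` instead of reading the removed `_system` field. A minimal sketch of what that looks like from the caller's side, assuming credentials for the chosen provider are configured; the model and provider strings are the ones that appear in this diff:

```python
from pydantic_ai.models.gemini import GeminiModel

# Assumes GEMINI_API_KEY / Vertex credentials are already set up; otherwise
# provider construction fails before the properties are ever read.
gla = GeminiModel('gemini-2.0-flash', provider='google-gla')
vertex = GeminiModel('gemini-2.0-flash', provider='google-vertex')

print(gla.model_name)  # 'gemini-2.0-flash'
print(gla.system)      # the provider's name, e.g. 'google-gla'
print(vertex.system)   # 'google-vertex'
```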
pydantic_ai/models/google.py
CHANGED
@@ -144,7 +144,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)

     def __init__(
         self,
@@ -168,9 +167,7 @@ class GoogleModel(Model):

         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -179,6 +176,16 @@ class GoogleModel(Model):
     def base_url(self) -> str:
         return self._provider.base_url

+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -195,7 +202,7 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> usage.…
+    ) -> usage.RequestUsage:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
@@ -209,7 +216,7 @@ class GoogleModel(Model):
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.…
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
             config.update(
                 system_instruction=generation_config.get('system_instruction'),
@@ -238,9 +245,8 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior(  # pragma: no cover
                 'Total tokens missing from Gemini response', str(response)
             )
-        return usage.…
-            …
-            total_tokens=response.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response.total_tokens,
         )

     @asynccontextmanager
@@ -256,16 +262,6 @@ class GoogleModel(Model):
             response = await self._generate_content(messages, True, model_settings, model_request_parameters)
             yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
@@ -392,9 +388,12 @@ class GoogleModel(Model):
         if finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
-            parts, …
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
         )

     async def _process_streamed_response(
@@ -590,7 +589,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
-    usage: usage.…
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -627,7 +626,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage, …
+        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -647,31 +646,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict:
     return ToolConfigDict(function_calling_config=function_calling_config)


-def _metadata_as_usage(response: GenerateContentResponse) -> usage.…
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     metadata = response.usage_metadata
     if metadata is None:
-        return usage.…
-    metadata = metadata.model_dump(exclude_defaults=True)
-
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.…
-        details['cached_content_tokens'] = cached_content_token_count
+    if cached_content_token_count := metadata.cached_content_token_count:
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.…
+    if thoughts_token_count := metadata.thoughts_token_count:
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.…
+    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
         details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

-    …
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
+    for prefix, metadata_details in [
+        ('prompt', metadata.prompt_tokens_details),
+        ('cache', metadata.cache_tokens_details),
+        ('candidates', metadata.candidates_tokens_details),
+        ('tool_use_prompt', metadata.tool_use_prompt_tokens_details),
+    ]:
+        assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details
+        if not metadata_details:
+            continue
+        for detail in metadata_details:
+            if not detail.modality or not detail.token_count:  # pragma: no cover
+                continue
+            details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
+            if detail.modality != 'AUDIO':
+                continue
+            if metadata_details is metadata.prompt_tokens_details:
+                input_audio_tokens = detail.token_count
+            elif metadata_details is metadata.candidates_tokens_details:
+                output_audio_tokens = detail.token_count
+            elif metadata_details is metadata.cache_tokens_details:  # pragma: no branch
+                cache_audio_read_tokens = detail.token_count
+
+    return usage.RequestUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        cache_read_tokens=cached_content_token_count or 0,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
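Both `_metadata_as_usage` implementations now return a `usage.RequestUsage` and promote audio-modality token counts to dedicated fields instead of leaving everything in `details`. A rough sketch of the kind of object the Google mapping produces, with made-up numbers, to show the fields this diff introduces:

```python
from pydantic_ai.usage import RequestUsage

# Illustrative values only; in the library this object is built by
# _metadata_as_usage() from GenerateContentResponse.usage_metadata.
example = RequestUsage(
    input_tokens=1200,            # from prompt_token_count
    output_tokens=250,            # from candidates_token_count
    cache_read_tokens=800,        # from cached_content_token_count
    input_audio_tokens=600,       # AUDIO entry in prompt_tokens_details
    output_audio_tokens=0,
    cache_audio_read_tokens=400,  # AUDIO entry in cache_tokens_details
    details={'audio_prompt_tokens': 600, 'text_prompt_tokens': 600},
)
print(example.input_audio_tokens, example.details['audio_prompt_tokens'])
```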
pydantic_ai/models/groq.py
CHANGED
@@ -118,7 +118,7 @@ class GroqModel(Model):
     client: AsyncGroq = field(repr=False)

     _model_name: GroqModelName = field(repr=False)
-    …
+    _provider: Provider[AsyncGroq] = field(repr=False)

     def __init__(
         self,
@@ -143,6 +143,7 @@ class GroqModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -151,6 +152,16 @@ class GroqModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -162,7 +173,6 @@ class GroqModel(Model):
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -180,16 +190,6 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> GroqModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -285,7 +285,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items, …
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -457,6 +461,7 @@ class GroqStreamedResponse(StreamedResponse):
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -483,7 +488,7 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.…
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
     response_usage = None
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
@@ -491,10 +496,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us…
         response_usage = completion.x_groq.usage

     if response_usage is None:
-        return usage.…
+        return usage.RequestUsage()

-    return usage.…
-        …
-        total_tokens=response_usage.total_tokens,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
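The Groq response construction, like the Gemini and Google ones above, now passes `provider_request_id` and a `RequestUsage` built from `prompt_tokens`/`completion_tokens`. A hedged sketch of the resulting `ModelResponse`, with placeholder values, to show which attributes downstream code can rely on:

```python
from pydantic_ai.messages import ModelResponse, TextPart
from pydantic_ai.usage import RequestUsage

# Placeholder stand-in for what GroqModel._process_response builds; the model
# name and request id below are invented for illustration.
response = ModelResponse(
    parts=[TextPart(content='hello')],
    usage=RequestUsage(input_tokens=42, output_tokens=7),
    model_name='llama-3.3-70b-versatile',
    provider_request_id='req_123',
)
print(response.provider_request_id, response.usage.output_tokens)
```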
pydantic_ai/models/huggingface.py
CHANGED
@@ -35,7 +35,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
-from ..profiles import ModelProfile
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -114,13 +114,15 @@ class HuggingFaceModel(Model):
     client: AsyncInferenceClient = field(repr=False)

     _model_name: str = field(repr=False)
-    …
+    _provider: Provider[AsyncInferenceClient] = field(repr=False)

     def __init__(
         self,
         model_name: str,
         *,
         provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Hugging Face model.

@@ -128,13 +130,27 @@ class HuggingFaceModel(Model):
             model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
             provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
                 instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
-        self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

+        super().__init__(settings=settings, profile=profile or provider.model_profile)
+
+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -146,7 +162,6 @@ class HuggingFaceModel(Model):
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -163,16 +178,6 @@ class HuggingFaceModel(Model):
         )
         yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> HuggingFaceModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -266,7 +271,7 @@ class HuggingFaceModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            …
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -444,11 +449,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):

                 # Handle the text part of the response
                 content = choice.delta.content
-                if content:
+                if content is not None:
                     maybe_event = self._parts_manager.handle_text_delta(
                         vendor_part_id='content',
                         content=content,
                         thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                     )
                     if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
@@ -474,14 +480,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.…
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.…
+        return usage.RequestUsage()

-    return usage.…
-        …
-        total_tokens=response_usage.total_tokens,
-        details=None,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )
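`HuggingFaceModel.__init__` gains `profile` and `settings` parameters and now calls `super().__init__(...)` like the other model classes. A minimal sketch of the new signature in use, assuming a Hugging Face API token is configured (e.g. via HF_TOKEN) and using an arbitrary model id:

```python
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.settings import ModelSettings

# The model id is arbitrary; the 'huggingface' provider needs a configured
# token (or an explicit Provider instance) at construction time.
model = HuggingFaceModel(
    'Qwen/Qwen3-235B-A22B',
    settings=ModelSettings(temperature=0.2),  # per-model default settings
)
print(model.model_name, model.system)
```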
pydantic_ai/models/instrumented.py
CHANGED
@@ -280,14 +280,14 @@ class InstrumentedModel(WrapperModel):
             'gen_ai.request.model': request_model,
             'gen_ai.response.model': response_model,
         }
-        if response.usage.…
+        if response.usage.input_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.…
+                response.usage.input_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'input'},
             )
-        if response.usage.…
+        if response.usage.output_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.…
+                response.usage.output_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'output'},
             )
pydantic_ai/models/mcp_sampling.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast

-from .. import _mcp, exceptions…
+from .. import _mcp, exceptions
 from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
@@ -63,7 +63,6 @@ class MCPSamplingModel(Model):
         if result.role == 'assistant':
             return ModelResponse(
                 parts=[_mcp.map_from_sampling_content(result.content)],
-                usage=usage.Usage(requests=1),
                 model_name=result.model,
             )
         else: