pydantic-ai-slim 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic by the registry.
- pydantic_ai/__init__.py +2 -1
- pydantic_ai/_a2a.py +3 -4
- pydantic_ai/_agent_graph.py +5 -2
- pydantic_ai/_output.py +130 -20
- pydantic_ai/_utils.py +6 -1
- pydantic_ai/agent.py +13 -10
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/messages.py +6 -4
- pydantic_ai/models/__init__.py +34 -1
- pydantic_ai/models/anthropic.py +5 -2
- pydantic_ai/models/bedrock.py +5 -2
- pydantic_ai/models/cohere.py +5 -2
- pydantic_ai/models/fallback.py +1 -0
- pydantic_ai/models/function.py +13 -2
- pydantic_ai/models/gemini.py +13 -10
- pydantic_ai/models/google.py +5 -2
- pydantic_ai/models/groq.py +5 -2
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +12 -12
- pydantic_ai/models/mistral.py +6 -3
- pydantic_ai/models/openai.py +16 -4
- pydantic_ai/models/test.py +22 -1
- pydantic_ai/models/wrapper.py +6 -0
- pydantic_ai/output.py +65 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +16 -5
- {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/METADATA +7 -5
- {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/RECORD +35 -33
- {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/huggingface.py ADDED
@@ -0,0 +1,463 @@
+from __future__ import annotations as _annotations
+
+import base64
+from collections.abc import AsyncIterable, AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Literal, Union, cast, overload
+
+from typing_extensions import assert_never
+
+from pydantic_ai._thinking_part import split_content_into_text_and_thinking
+from pydantic_ai.providers import Provider, infer_provider
+
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    DocumentUrl,
+    ImageUrl,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    ModelResponseStreamEvent,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ThinkingPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+    VideoUrl,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+
+try:
+    import aiohttp
+    from huggingface_hub import (
+        AsyncInferenceClient,
+        ChatCompletionInputMessage,
+        ChatCompletionInputMessageChunk,
+        ChatCompletionInputTool,
+        ChatCompletionInputToolCall,
+        ChatCompletionInputURL,
+        ChatCompletionOutput,
+        ChatCompletionOutputMessage,
+        ChatCompletionStreamOutput,
+    )
+    from huggingface_hub.errors import HfHubHTTPError
+
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
+        'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
+    ) from _import_error
+
+__all__ = (
+    'HuggingFaceModel',
+    'HuggingFaceModelSettings',
+)
+
+
+HFSystemPromptRole = Literal['system', 'user']
+
+LatestHuggingFaceModelNames = Literal[
+    'deepseek-ai/DeepSeek-R1',
+    'meta-llama/Llama-3.3-70B-Instruct',
+    'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
+    'meta-llama/Llama-4-Scout-17B-16E-Instruct',
+    'Qwen/QwQ-32B',
+    'Qwen/Qwen2.5-72B-Instruct',
+    'Qwen/Qwen3-235B-A22B',
+    'Qwen/Qwen3-32B',
+]
+"""Latest Hugging Face models."""
+
+
+HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
+"""Possible Hugging Face model names.
+
+You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+"""
+
+
+class HuggingFaceModelSettings(ModelSettings, total=False):
+    """Settings used for a Hugging Face model request."""
+
+    # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    # This class is a placeholder for any future huggingface-specific settings
+
+
+@dataclass(init=False)
+class HuggingFaceModel(Model):
+    """A model that uses Hugging Face Inference Providers.
+
+    Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    client: AsyncInferenceClient = field(repr=False)
+
+    _model_name: str = field(repr=False)
+    _system: str = field(default='huggingface', repr=False)
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+    ):
+        """Initialize a Hugging Face model.
+
+        Args:
+            model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+            provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
+                instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+        """
+        self._model_name = model_name
+        self._provider = provider
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+        )
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+        )
+        yield await self._process_streamed_response(response)
+
+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._system
+
+    @overload
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
+
+    @overload
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ChatCompletionOutput: ...
+
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]:
+        tools = self._get_tools(model_request_parameters)
+
+        if not tools:
+            tool_choice: Literal['none', 'required', 'auto'] | None = None
+        elif not model_request_parameters.allow_text_output:
+            tool_choice = 'required'
+        else:
+            tool_choice = 'auto'
+
+        hf_messages = await self._map_messages(messages)
+
+        try:
+            return await self.client.chat.completions.create(  # type: ignore
+                model=self._model_name,
+                messages=hf_messages,  # type: ignore
+                tools=tools,
+                tool_choice=tool_choice or None,
+                stream=stream,
+                stop=model_settings.get('stop_sequences', None),
+                temperature=model_settings.get('temperature', None),
+                top_p=model_settings.get('top_p', None),
+                seed=model_settings.get('seed', None),
+                presence_penalty=model_settings.get('presence_penalty', None),
+                frequency_penalty=model_settings.get('frequency_penalty', None),
+                logit_bias=model_settings.get('logit_bias', None),  # type: ignore
+                logprobs=model_settings.get('logprobs', None),
+                top_logprobs=model_settings.get('top_logprobs', None),
+                extra_body=model_settings.get('extra_body'),  # type: ignore
+            )
+        except aiohttp.ClientResponseError as e:
+            raise ModelHTTPError(
+                status_code=e.status,
+                model_name=self.model_name,
+                body=e.response_error_payload,  # type: ignore
+            ) from e
+        except HfHubHTTPError as e:
+            raise ModelHTTPError(
+                status_code=e.response.status_code,
+                model_name=self.model_name,
+                body=e.response.content,
+            ) from e
+
+    def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        if response.created:
+            timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+        else:
+            timestamp = _now_utc()
+
+        choice = response.choices[0]
+        content = choice.message.content
+        tool_calls = choice.message.tool_calls
+
+        items: list[ModelResponsePart] = []
+
+        if content is not None:
+            items.extend(split_content_into_text_and_thinking(content))
+        if tool_calls is not None:
+            for c in tool_calls:
+                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            vendor_id=response.id,
+        )
+
+    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+        """Process a streamed response, and prepare a streaming response to return."""
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Streamed response ended without content or tool calls'
+            )
+
+        return HuggingFaceStreamedResponse(
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )
+
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
+        return tools
+
+    async def _map_messages(
+        self, messages: list[ModelMessage]
+    ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
+        """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
+        hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                async for item in self._map_user_message(message):
+                    hf_messages.append(item)
+            elif isinstance(message, ModelResponse):
+                texts: list[str] = []
+                tool_calls: list[ChatCompletionInputToolCall] = []
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        texts.append(item.content)
+                    elif isinstance(item, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(item))
+                    elif isinstance(item, ThinkingPart):
+                        # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
+                        # please open an issue. The below code is the code to send thinking to the provider.
+                        # texts.append(f'<think>\n{item.content}\n</think>')
+                        pass
+                    else:
+                        assert_never(item)
+                message_param = ChatCompletionInputMessage(role='assistant')  # type: ignore
+                if texts:
+                    # Note: model responses from this model should only have one text item, so the following
+                    # shouldn't merge multiple texts into one unless you switch models between runs:
+                    message_param['content'] = '\n\n'.join(texts)
+                if tool_calls:
+                    message_param['tool_calls'] = tool_calls
+                hf_messages.append(message_param)
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
+        return hf_messages
+
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall:
+        return ChatCompletionInputToolCall.parse_obj_as_instance(  # type: ignore
+            {
+                'id': _guard_tool_call_id(t=t),
+                'type': 'function',
+                'function': {
+                    'name': t.tool_name,
+                    'arguments': t.args_as_json_str(),
+                },
+            }
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
+        tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance(  # type: ignore
+            {
+                'type': 'function',
+                'function': {
+                    'name': f.name,
+                    'description': f.description,
+                    'parameters': f.parameters_json_schema,
+                },
+            }
+        )
+        if f.strict is not None:
+            tool_param['function']['strict'] = f.strict
+        return tool_param
+
+    async def _map_user_message(
+        self, message: ModelRequest
+    ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content})  # type: ignore
+            elif isinstance(part, UserPromptPart):
+                yield await self._map_user_prompt(part)
+            elif isinstance(part, ToolReturnPart):
+                yield ChatCompletionOutputMessage.parse_obj_as_instance(  # type: ignore
+                    {
+                        'role': 'tool',
+                        'tool_call_id': _guard_tool_call_id(t=part),
+                        'content': part.model_response_str(),
+                    }
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
+                        {'role': 'user', 'content': part.model_response()}
+                    )
+                else:
+                    yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
+                        {
+                            'role': 'tool',
+                            'tool_call_id': _guard_tool_call_id(t=part),
+                            'content': part.model_response(),
+                        }
+                    )
+            else:
+                assert_never(part)
+
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
+        content: str | list[ChatCompletionInputMessage]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(ChatCompletionInputMessageChunk(type='text', text=item))  # type: ignore
+                elif isinstance(item, ImageUrl):
+                    url = ChatCompletionInputURL(url=item.url)  # type: ignore
+                    content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}')  # type: ignore
+                        content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
+                    else:  # pragma: no cover
+                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                elif isinstance(item, AudioUrl):
+                    raise NotImplementedError('AudioUrl is not supported for Hugging Face')
+                elif isinstance(item, DocumentUrl):
+                    raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
+                elif isinstance(item, VideoUrl):
+                    raise NotImplementedError('VideoUrl is not supported for Hugging Face')
+                else:
+                    assert_never(item)
+        return ChatCompletionInputMessage(role='user', content=content)  # type: ignore
+
+
+@dataclass
+class HuggingFaceStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Hugging Face models."""
+
+    _model_name: str
+    _response: AsyncIterable[ChatCompletionStreamOutput]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)
+
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,  # type: ignore
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event
+
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._timestamp
+
+
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
+    response_usage = response.usage
+    if response_usage is None:
+        return usage.Usage()
+
+    return usage.Usage(
+        request_tokens=response_usage.prompt_tokens,
+        response_tokens=response_usage.completion_tokens,
+        total_tokens=response_usage.total_tokens,
+        details=None,
+    )
pydantic_ai/models/instrumented.py CHANGED
@@ -138,7 +138,7 @@ class InstrumentationSettings:
                 **tokens_histogram_kwargs,
                 explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
             )
-        except TypeError:
+        except TypeError:
             # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
             self.tokens_histogram = self.meter.create_histogram(
                 **tokens_histogram_kwargs,  # pyright: ignore
@@ -182,15 +182,15 @@ GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
 GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
 
 
-@dataclass
+@dataclass(init=False)
 class InstrumentedModel(WrapperModel):
     """Model which wraps another model so that requests are instrumented with OpenTelemetry.
 
     See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
     """
 
-
-    """
+    instrumentation_settings: InstrumentationSettings
+    """Instrumentation settings for this model."""
 
     def __init__(
         self,
@@ -198,7 +198,7 @@ class InstrumentedModel(WrapperModel):
         options: InstrumentationSettings | None = None,
     ) -> None:
         super().__init__(wrapped)
-        self.
+        self.instrumentation_settings = options or InstrumentationSettings()
 
     async def request(
         self,
@@ -260,7 +260,7 @@ class InstrumentedModel(WrapperModel):
 
         record_metrics: Callable[[], None] | None = None
         try:
-            with self.
+            with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
 
                 def finish(response: ModelResponse):
                     # FallbackModel updates these span attributes.
@@ -278,12 +278,12 @@ class InstrumentedModel(WrapperModel):
                         'gen_ai.response.model': response_model,
                     }
                     if response.usage.request_tokens:  # pragma: no branch
-                        self.
+                        self.instrumentation_settings.tokens_histogram.record(
                             response.usage.request_tokens,
                             {**metric_attributes, 'gen_ai.token.type': 'input'},
                         )
                     if response.usage.response_tokens:  # pragma: no branch
-                        self.
+                        self.instrumentation_settings.tokens_histogram.record(
                             response.usage.response_tokens,
                             {**metric_attributes, 'gen_ai.token.type': 'output'},
                         )
@@ -294,8 +294,8 @@ class InstrumentedModel(WrapperModel):
                 if not span.is_recording():
                     return
 
-                events = self.
-                for event in self.
+                events = self.instrumentation_settings.messages_to_otel_events(messages)
+                for event in self.instrumentation_settings.messages_to_otel_events([response]):
                 events.append(
                     Event(
                         'gen_ai.choice',
@@ -328,9 +328,9 @@ class InstrumentedModel(WrapperModel):
             record_metrics()
 
     def _emit_events(self, span: Span, events: list[Event]) -> None:
-        if self.
+        if self.instrumentation_settings.event_mode == 'logs':
             for event in events:
-                self.
+                self.instrumentation_settings.event_logger.emit(event)
         else:
             attr_name = 'events'
             span.set_attributes(
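The hunks above expose the wrapper's settings under the attribute `instrumentation_settings` (and make the dataclass `init=False`). A small sketch of reading the renamed attribute, using the package's own `TestModel` so no provider credentials are needed:

# Sketch: wrap a model with OpenTelemetry instrumentation and read the renamed attribute.
from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel
from pydantic_ai.models.test import TestModel

instrumented = InstrumentedModel(TestModel(), InstrumentationSettings())
# `event_mode` is the field consulted by `_emit_events` in the diff above.
print(instrumented.instrumentation_settings.event_mode)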
pydantic_ai/models/mistral.py CHANGED
@@ -75,7 +75,7 @@ try:
     from mistralai.models.usermessage import UserMessage as MistralUserMessage
     from mistralai.types.basemodel import Unset as MistralUnset
     from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
-except ImportError as e:  # pragma:
+except ImportError as e:  # pragma: no cover
     raise ImportError(
         'Please install `mistral` to use the Mistral model, '
         'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -125,6 +125,7 @@ class MistralModel(Model):
         provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
         profile: ModelProfileSpec | None = None,
         json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Mistral model.
 
@@ -135,6 +136,7 @@ class MistralModel(Model):
             created using the other parameters.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
         self.json_mode_schema_prompt = json_mode_schema_prompt
@@ -142,7 +144,8 @@ class MistralModel(Model):
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
-
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)
 
     @property
     def base_url(self) -> str:
@@ -214,7 +217,7 @@ class MistralModel(Model):
         except SDKError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma:
+            raise  # pragma: no cover
 
         assert response, 'A unexpected empty response from Mistral.'
         return response
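The new `settings` parameter is forwarded to `Model.__init__` via the `super().__init__(...)` call above, so it provides per-instance default `ModelSettings`. A hedged sketch (the model name is illustrative and a configured Mistral API key is assumed):

# Sketch: per-model default settings via the new `settings` constructor argument.
from pydantic_ai.models.mistral import MistralModel
from pydantic_ai.settings import ModelSettings

# These values become the defaults for every request made through this model instance.
model = MistralModel(
    'mistral-large-latest',
    settings=ModelSettings(temperature=0.0, max_tokens=512),
)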
pydantic_ai/models/openai.py CHANGED
@@ -195,6 +195,7 @@ class OpenAIModel(Model):
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize an OpenAI model.
 
@@ -206,16 +207,18 @@ class OpenAIModel(Model):
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
+            settings: Default model settings for this model instance.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
-        self._profile = profile or provider.model_profile
 
         self.system_prompt_role = system_prompt_role
 
+        super().__init__(settings=settings, profile=profile or provider.model_profile)
+
     @property
     def base_url(self) -> str:
         return str(self.client.base_url)
@@ -342,7 +345,7 @@ class OpenAIModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma:
+            raise  # pragma: no cover
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -598,6 +601,7 @@ class OpenAIResponsesModel(Model):
         provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize an OpenAI Responses model.
 
@@ -605,13 +609,15 @@ class OpenAIResponsesModel(Model):
             model_name: The name of the OpenAI model to use.
             provider: The provider to use. Defaults to `'openai'`.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Default model settings for this model instance.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
-
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)
 
     @property
     def model_name(self) -> OpenAIModelName:
@@ -775,7 +781,7 @@ class OpenAIResponsesModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma:
+            raise  # pragma: no cover
 
     def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
         reasoning_effort = model_settings.get('openai_reasoning_effort', None)
@@ -988,6 +994,12 @@ class OpenAIStreamedResponse(StreamedResponse):
             if content is not None:
                 yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
 
+            # Handle reasoning part of the response, present in DeepSeek models
+            if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id='reasoning_content', content=reasoning_content
+                )
+
             for dtc in choice.delta.tool_calls or []:
                 maybe_event = self._parts_manager.handle_tool_call_delta(
                     vendor_part_id=dtc.index,
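Two changes land in openai.py: `OpenAIModel` and `OpenAIResponsesModel` gain the same `settings` constructor argument as the Mistral model, and streamed responses now turn DeepSeek-style `reasoning_content` deltas into thinking parts. A short sketch of the former (model name illustrative, OpenAI credentials assumed):

# Sketch: default settings on the OpenAI models, mirroring the Mistral example above.
from pydantic_ai.models.openai import OpenAIModel, OpenAIResponsesModel
from pydantic_ai.settings import ModelSettings

defaults = ModelSettings(temperature=0.0)
chat_model = OpenAIModel('gpt-4o', settings=defaults)
responses_model = OpenAIResponsesModel('gpt-4o', settings=defaults)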