pydantic-ai-slim 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +66 -56
- pydantic_ai/_parts_manager.py +5 -4
- pydantic_ai/_tool_manager.py +50 -29
- pydantic_ai/agent/__init__.py +62 -75
- pydantic_ai/models/__init__.py +28 -0
- pydantic_ai/models/anthropic.py +20 -20
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/fallback.py +7 -2
- pydantic_ai/models/google.py +66 -6
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/huggingface.py +9 -2
- pydantic_ai/models/openai.py +31 -5
- pydantic_ai/profiles/__init__.py +10 -1
- pydantic_ai/profiles/deepseek.py +1 -1
- pydantic_ai/profiles/moonshotai.py +1 -1
- pydantic_ai/profiles/qwen.py +4 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/huggingface.py +27 -0
- pydantic_ai/providers/ollama.py +105 -0
- pydantic_ai/providers/openrouter.py +2 -0
- pydantic_ai/result.py +1 -1
- pydantic_ai/tools.py +9 -9
- pydantic_ai/usage.py +17 -1
- {pydantic_ai_slim-0.7.0.dist-info → pydantic_ai_slim-0.7.2.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.7.0.dist-info → pydantic_ai_slim-0.7.2.dist-info}/RECORD +28 -27
- {pydantic_ai_slim-0.7.0.dist-info → pydantic_ai_slim-0.7.2.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.0.dist-info → pydantic_ai_slim-0.7.2.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.0.dist-info → pydantic_ai_slim-0.7.2.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py
CHANGED
@@ -8,14 +8,6 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
-from anthropic.types.beta import (
-    BetaCitationsDelta,
-    BetaCodeExecutionToolResultBlock,
-    BetaCodeExecutionToolResultBlockParam,
-    BetaInputJSONDelta,
-    BetaServerToolUseBlockParam,
-    BetaWebSearchToolResultBlockParam,
-)
 from typing_extensions import assert_never
 
 from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
@@ -47,24 +39,21 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types.beta import (
         BetaBase64PDFBlockParam,
         BetaBase64PDFSourceParam,
+        BetaCitationsDelta,
         BetaCodeExecutionTool20250522Param,
+        BetaCodeExecutionToolResultBlock,
+        BetaCodeExecutionToolResultBlockParam,
         BetaContentBlock,
         BetaContentBlockParam,
         BetaImageBlockParam,
+        BetaInputJSONDelta,
         BetaMessage,
         BetaMessageParam,
         BetaMetadataParam,
@@ -78,6 +67,7 @@ try:
         BetaRawMessageStreamEvent,
         BetaRedactedThinkingBlock,
         BetaServerToolUseBlock,
+        BetaServerToolUseBlockParam,
         BetaSignatureDelta,
         BetaTextBlock,
         BetaTextBlockParam,
@@ -94,6 +84,7 @@ try:
         BetaToolUseBlockParam,
         BetaWebSearchTool20250305Param,
         BetaWebSearchToolResultBlock,
+        BetaWebSearchToolResultBlockParam,
     )
     from anthropic.types.beta.beta_web_search_tool_20250305_param import UserLocation
     from anthropic.types.model_param import ModelParam
@@ -246,7 +237,9 @@ class AnthropicModel(Model):
     ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-
+        builtin_tools, tool_headers = self._get_builtin_tools(model_request_parameters)
+        tools += builtin_tools
+
         tool_choice: BetaToolChoiceParam | None
 
         if not tools:
@@ -264,8 +257,10 @@ class AnthropicModel(Model):
 
         try:
             extra_headers = model_settings.get('extra_headers', {})
+            for k, v in tool_headers.items():
+                extra_headers.setdefault(k, v)
             extra_headers.setdefault('User-Agent', get_user_agent())
-
+
             return await self.client.beta.messages.create(
                 max_tokens=model_settings.get('max_tokens', 4096),
                 system=system_prompt or NOT_GIVEN,
@@ -352,8 +347,11 @@ class AnthropicModel(Model):
    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

-    def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
+    def _get_builtin_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> tuple[list[BetaToolUnionParam], dict[str, str]]:
         tools: list[BetaToolUnionParam] = []
+        extra_headers: dict[str, str] = {}
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
@@ -361,18 +359,20 @@ class AnthropicModel(Model):
                     BetaWebSearchTool20250305Param(
                         name='web_search',
                         type='web_search_20250305',
+                        max_uses=tool.max_uses,
                         allowed_domains=tool.allowed_domains,
                         blocked_domains=tool.blocked_domains,
                         user_location=user_location,
                     )
                 )
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                extra_headers['anthropic-beta'] = 'code-execution-2025-05-22'
                 tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
             else:  # pragma: no cover
                 raise UserError(
                     f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
                 )
-        return tools
+        return tools, extra_headers
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
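The net effect of the `_get_builtin_tools` change is that requesting `CodeExecutionTool` now also injects the required `anthropic-beta` header instead of leaving that to the caller. A minimal sketch of exercising this path (the model name is illustrative, and `ANTHROPIC_API_KEY` is assumed to be set):

    from pydantic_ai import Agent
    from pydantic_ai.builtin_tools import CodeExecutionTool

    # The 'anthropic-beta: code-execution-2025-05-22' header is now added by
    # AnthropicModel itself when this builtin tool is present.
    agent = Agent('anthropic:claude-sonnet-4-0', builtin_tools=[CodeExecutionTool()])
    result = agent.run_sync('Use code execution to compute 2**32.')
    print(result.output)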
pydantic_ai/models/bedrock.py
CHANGED
@@ -648,7 +648,7 @@ class BedrockStreamedResponse(StreamedResponse):
             )
             if 'text' in delta:
                 maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
-                if maybe_event is not None:
+                if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
             if 'toolUse' in delta:
                 tool_use = delta['toolUse']
pydantic_ai/models/fallback.py
CHANGED
@@ -11,6 +11,7 @@ from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel
 
 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 if TYPE_CHECKING:
@@ -65,8 +66,9 @@ class FallbackModel(Model):
 
         for model in self.models:
             customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
             try:
-                response = await model.request(messages, model_settings, customized_model_request_parameters)
+                response = await model.request(messages, merged_settings, customized_model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
@@ -91,10 +93,13 @@ class FallbackModel(Model):
 
         for model in self.models:
             customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
             async with AsyncExitStack() as stack:
                 try:
                     response = await stack.enter_async_context(
-                        model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
+                        model.request_stream(
+                            messages, merged_settings, customized_model_request_parameters, run_context
+                        )
                     )
                 except Exception as exc:
                     if self._fallback_on(exc):
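With `merge_model_settings`, per-model default settings on each wrapped model are now merged with the settings passed at request time, rather than being dropped on the fallback path. A sketch under assumed model names and settings:

    from pydantic_ai.models.fallback import FallbackModel
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.settings import ModelSettings

    # Each model's own default settings now survive the fallback path and are
    # merged with whatever settings the agent run supplies.
    primary = OpenAIModel('gpt-4o', settings=ModelSettings(temperature=0.0))
    backup = OpenAIModel('gpt-4o-mini', settings=ModelSettings(temperature=0.0))
    model = FallbackModel(primary, backup)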
pydantic_ai/models/google.py
CHANGED
@@ -52,6 +52,7 @@ try:
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        CountTokensConfigDict,
         ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
@@ -59,6 +60,7 @@ try:
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GenerationConfigDict,
         GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
@@ -188,6 +190,59 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.Usage:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        contents, generation_config = await self._build_content_and_config(
+            messages, model_settings, model_request_parameters
+        )
+
+        # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_schema` includes `typing._UnionGenericAlias`,
+        # so without this we'd need `pyright: ignore[reportUnknownMemberType]` on every line and wouldn't get type checking anyway.
+        generation_config = cast(dict[str, Any], generation_config)
+
+        config = CountTokensConfigDict(
+            http_options=generation_config.get('http_options'),
+        )
+        if self.system != 'google-gla':
+            # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
+            config.update(
+                system_instruction=generation_config.get('system_instruction'),
+                tools=cast(list[ToolDict], generation_config.get('tools')),
+                # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
+                generation_config=GenerationConfigDict(
+                    temperature=generation_config.get('temperature'),
+                    top_p=generation_config.get('top_p'),
+                    max_output_tokens=generation_config.get('max_output_tokens'),
+                    stop_sequences=generation_config.get('stop_sequences'),
+                    presence_penalty=generation_config.get('presence_penalty'),
+                    frequency_penalty=generation_config.get('frequency_penalty'),
+                    thinking_config=generation_config.get('thinking_config'),
+                    media_resolution=generation_config.get('media_resolution'),
+                    response_mime_type=generation_config.get('response_mime_type'),
+                    response_schema=generation_config.get('response_schema'),
+                ),
+            )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+            config=config,
+        )
+        if response.total_tokens is None:
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Total tokens missing from Gemini response', str(response)
+            )
+        return usage.Usage(
+            request_tokens=response.total_tokens,
+            total_tokens=response.total_tokens,
+        )
+
     @asynccontextmanager
     async def request_stream(
         self,
@@ -265,16 +320,23 @@ class GoogleModel(Model):
         model_settings: GoogleModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
-        tools = self._get_tools(model_request_parameters)
+        contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
+        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
+        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
 
+    async def _build_content_and_config(
+        self,
+        messages: list[ModelMessage],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[list[ContentUnionDict], GenerateContentConfigDict]:
+        tools = self._get_tools(model_request_parameters)
         response_mime_type = None
         response_schema = None
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError('Gemini does not support structured output and tools at the same time.')
-
             response_mime_type = 'application/json'
-
             output_object = model_request_parameters.output_object
             assert output_object is not None
             response_schema = self._map_response_schema(output_object)
@@ -311,9 +373,7 @@ class GoogleModel(Model):
             response_mime_type=response_mime_type,
             response_schema=response_schema,
         )
-
-        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
-        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
+        return contents, config
 
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
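The new `count_tokens` method reuses `_build_content_and_config`, so token counting sees the same contents and config as a real request, with the Gemini-unsupported fields stripped on the `google-gla` backend. A usage sketch (the model name is illustrative, and `GEMINI_API_KEY` is assumed to be configured):

    import asyncio

    from pydantic_ai.messages import ModelRequest
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.google import GoogleModel

    async def main() -> None:
        model = GoogleModel('gemini-1.5-flash')
        # Returns a Usage where request_tokens == total_tokens, taken from the count endpoint.
        counted = await model.count_tokens(
            [ModelRequest.user_text_prompt('hello world')],
            None,
            ModelRequestParameters(),
        )
        print(counted.total_tokens)

    asyncio.run(main())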
pydantic_ai/models/groq.py
CHANGED
@@ -457,6 +457,7 @@ class GroqStreamedResponse(StreamedResponse):
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
pydantic_ai/models/huggingface.py
CHANGED
@@ -35,7 +35,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
-from ..profiles import ModelProfile
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -121,6 +121,8 @@ class HuggingFaceModel(Model):
         model_name: str,
         *,
         provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Hugging Face model.
 
@@ -128,6 +130,8 @@ class HuggingFaceModel(Model):
             model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
             provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
                 instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
         self._provider = provider
@@ -135,6 +139,8 @@ class HuggingFaceModel(Model):
             provider = infer_provider(provider)
         self.client = provider.client
 
+        super().__init__(settings=settings, profile=profile or provider.model_profile)
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -444,11 +450,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
 
         # Handle the text part of the response
        content = choice.delta.content
-        if content:
+        if content is not None:
            maybe_event = self._parts_manager.handle_text_delta(
                vendor_part_id='content',
                content=content,
                thinking_tags=self._model_profile.thinking_tags,
+                ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
            )
            if maybe_event is not None:  # pragma: no branch
                yield maybe_event
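`HuggingFaceModel` now accepts the same `profile` and `settings` keyword arguments as the other model classes, and calls `super().__init__` so they actually take effect. A sketch (the model name is illustrative, and `HF_TOKEN` is assumed to be set for the default provider):

    from pydantic_ai.models.huggingface import HuggingFaceModel
    from pydantic_ai.settings import ModelSettings

    # Default settings attached to the model itself; the profile, if omitted,
    # now comes from the provider via provider.model_profile.
    model = HuggingFaceModel(
        'Qwen/Qwen3-235B-A22B',
        settings=ModelSettings(temperature=0.0),
    )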
pydantic_ai/models/openai.py
CHANGED
@@ -59,6 +59,11 @@ try:
     from openai.types.chat.chat_completion_content_part_image_param import ImageURL
     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
     from openai.types.chat.chat_completion_content_part_param import File, FileFile
+    from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
+    from openai.types.chat.chat_completion_message_function_tool_call import ChatCompletionMessageFunctionToolCall
+    from openai.types.chat.chat_completion_message_function_tool_call_param import (
+        ChatCompletionMessageFunctionToolCallParam,
+    )
     from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam
     from openai.types.chat.completion_create_params import (
         WebSearchOptions,
@@ -172,6 +177,14 @@ class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
     middle of the conversation.
     """
 
+    openai_text_verbosity: Literal['low', 'medium', 'high']
+    """Constrains the verbosity of the model's text response.
+
+    Lower values will result in more concise responses, while higher values will
+    result in more verbose responses. Currently supported values are `low`,
+    `medium`, and `high`.
+    """
+
 
 @dataclass(init=False)
 class OpenAIModel(Model):
@@ -204,6 +217,7 @@ class OpenAIModel(Model):
             'together',
             'heroku',
             'github',
+            'ollama',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -416,7 +430,14 @@ class OpenAIModel(Model):
             items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
+                if isinstance(c, ChatCompletionMessageFunctionToolCall):
+                    part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
+                elif isinstance(c, ChatCompletionMessageCustomToolCall):  # pragma: no cover
+                    # NOTE: Custom tool calls are not supported.
+                    # See <https://github.com/pydantic/pydantic-ai/issues/2513> for more details.
+                    raise RuntimeError('Custom tool calls are not supported')
+                else:
+                    assert_never(c)
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
         return ModelResponse(
@@ -476,7 +497,7 @@ class OpenAIModel(Model):
             openai_messages.append(item)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
-            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = []
             for item in message.parts:
                 if isinstance(item, TextPart):
                     texts.append(item.content)
@@ -507,8 +528,8 @@ class OpenAIModel(Model):
         return openai_messages
 
     @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-        return chat.ChatCompletionMessageToolCallParam(
+    def _map_tool_call(t: ToolCallPart) -> ChatCompletionMessageFunctionToolCallParam:
+        return ChatCompletionMessageFunctionToolCallParam(
            id=_guard_tool_call_id(t=t),
            type='function',
            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
@@ -807,6 +828,10 @@ class OpenAIResponsesModel(Model):
             openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions))
             instructions = NOT_GIVEN
 
+        if verbosity := model_settings.get('openai_text_verbosity'):
+            text = text or {}
+            text['verbosity'] = verbosity
+
         sampling_settings = (
             model_settings
             if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
@@ -1070,11 +1095,12 @@ class OpenAIStreamedResponse(StreamedResponse):
 
         # Handle the text part of the response
         content = choice.delta.content
-        if content:
+        if content is not None:
             maybe_event = self._parts_manager.handle_text_delta(
                 vendor_part_id='content',
                 content=content,
                 thinking_tags=self._model_profile.thinking_tags,
+                ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
             )
             if maybe_event is not None:  # pragma: no branch
                 yield maybe_event
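The new `openai_text_verbosity` setting is forwarded into the Responses API `text` parameter. A sketch of opting in (the model name is illustrative, and `OPENAI_API_KEY` is assumed to be set):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

    model = OpenAIResponsesModel('gpt-5')
    # 'low' asks the model for terser prose; 'high' for more expansive answers.
    agent = Agent(model, model_settings=OpenAIResponsesModelSettings(openai_text_verbosity='low'))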
pydantic_ai/profiles/__init__.py
CHANGED
@@ -20,7 +20,7 @@ __all__ = [
 
 @dataclass
 class ModelProfile:
-    """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used."""
+    """Describes how requests to and responses from specific models or families of models need to be constructed and processed to get the best results, independent of the model and provider classes used."""
 
     supports_tools: bool = True
     """Whether the model supports tools."""
@@ -46,6 +46,15 @@ class ModelProfile:
     thinking_tags: tuple[str, str] = ('<think>', '</think>')
     """The tags used to indicate thinking parts in the model's output. Defaults to ('<think>', '</think>')."""
 
+    ignore_streamed_leading_whitespace: bool = False
+    """Whether to ignore leading whitespace when streaming a response.
+
+    This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
+    which we don't want to end up treating as a final result when using `run_stream` with `str` as a valid `output_type`.
+
+    This is currently only used by `OpenAIModel`, `HuggingFaceModel`, and `GroqModel`.
+    """
+
     @classmethod
     def from_profile(cls, profile: ModelProfile | None) -> Self:
         """Build a ModelProfile subclass instance from a ModelProfile instance."""
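A profile can opt in to the new flag directly; the Qwen profile below does so by default. A sketch of setting it explicitly on a custom profile (the model/provider pairing is illustrative and assumes `OLLAMA_BASE_URL` is set):

    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.profiles import ModelProfile

    # Suppresses leading-whitespace-only text events so they aren't mistaken
    # for a final string result during run_stream.
    profile = ModelProfile(ignore_streamed_leading_whitespace=True)
    model = OpenAIModel('qwen3:8b', provider='ollama', profile=profile)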
pydantic_ai/profiles/deepseek.py
CHANGED
pydantic_ai/profiles/qwen.py
CHANGED
@@ -5,4 +5,7 @@ from . import InlineDefsJsonSchemaTransformer, ModelProfile
 
 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
-    return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer)
+    return ModelProfile(
+        json_schema_transformer=InlineDefsJsonSchemaTransformer,
+        ignore_streamed_leading_whitespace=True,
+    )
pydantic_ai/providers/__init__.py
CHANGED
@@ -123,6 +123,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .huggingface import HuggingFaceProvider
 
         return HuggingFaceProvider
+    elif provider == 'ollama':
+        from .ollama import OllamaProvider
+
+        return OllamaProvider
     elif provider == 'github':
         from .github import GitHubProvider
 
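With the registry entry in place, the string 'ollama' now resolves like any other provider name. A short sketch (requires `OLLAMA_BASE_URL` to be set):

    from pydantic_ai.providers import infer_provider

    provider = infer_provider('ollama')  # returns an OllamaProvider instance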
pydantic_ai/providers/huggingface.py
CHANGED
@@ -6,6 +6,13 @@ from typing import overload
 from httpx import AsyncClient
 
 from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
 
 try:
     from huggingface_hub import AsyncInferenceClient
@@ -33,6 +40,26 @@ class HuggingFaceProvider(Provider[AsyncInferenceClient]):
     def client(self) -> AsyncInferenceClient:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'deepseek-ai': deepseek_model_profile,
+            'google': google_model_profile,
+            'qwen': qwen_model_profile,
+            'meta-llama': meta_model_profile,
+            'mistralai': mistral_model_profile,
+            'moonshotai': moonshotai_model_profile,
+        }
+
+        if '/' not in model_name:
+            return None
+
+        model_name = model_name.lower()
+        provider, model_name = model_name.split('/', 1)
+        if provider in provider_to_profile:
+            return provider_to_profile[provider](model_name)
+
+        return None
+
     @overload
     def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
     @overload
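`HuggingFaceProvider.model_profile` keys off the organization prefix of an `org/model` name. A sketch (the token and model names are illustrative):

    from pydantic_ai.providers.huggingface import HuggingFaceProvider

    provider = HuggingFaceProvider(api_key='hf_...')    # hypothetical token
    provider.model_profile('Qwen/Qwen3-32B')            # resolved via qwen_model_profile
    provider.model_profile('gpt2')                      # None: no 'org/' prefix to match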
pydantic_ai/providers/ollama.py
ADDED
@@ -0,0 +1,105 @@
+from __future__ import annotations as _annotations
+
+import os
+
+import httpx
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Ollama provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class OllamaProvider(Provider[AsyncOpenAI]):
+    """Provider for local or remote Ollama API."""
+
+    @property
+    def name(self) -> str:
+        return 'ollama'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {
+            'llama': meta_model_profile,
+            'gemma': google_model_profile,
+            'qwen': qwen_model_profile,
+            'qwq': qwen_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+            'command': cohere_model_profile,
+        }
+
+        profile = None
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                profile = profile_func(model_name)
+
+        # As OllamaProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # we need to maintain that behavior unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        """Create a new Ollama provider.
+
+        Args:
+            base_url: The base url for the Ollama requests. If not provided, the `OLLAMA_BASE_URL` environment variable
+                will be used if available.
+            api_key: The API key to use for authentication, if not provided, the `OLLAMA_API_KEY` environment variable
+                will be used if available.
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
+            self._client = openai_client
+        else:
+            base_url = base_url or os.getenv('OLLAMA_BASE_URL')
+            if not base_url:
+                raise UserError(
+                    'Set the `OLLAMA_BASE_URL` environment variable or pass it via `OllamaProvider(base_url=...)`'
+                    'to use the Ollama provider.'
+                )
+
+            # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+            # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
+            api_key = api_key or os.getenv('OLLAMA_API_KEY') or 'api-key-not-set'
+
+            if http_client is not None:
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
+            else:
+                http_client = cached_async_http_client(provider='ollama')
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
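Putting the new provider together with `OpenAIModel` (the URL and model name are illustrative and assume a locally running Ollama server):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.providers.ollama import OllamaProvider

    model = OpenAIModel(
        'qwen3:8b',
        provider=OllamaProvider(base_url='http://localhost:11434/v1'),
    )
    agent = Agent(model)
    print(agent.run_sync('What is 2 + 2?').output)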
pydantic_ai/providers/openrouter.py
CHANGED
@@ -17,6 +17,7 @@ from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.profiles.grok import grok_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
 from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
@@ -57,6 +58,7 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
             'amazon': amazon_model_profile,
             'deepseek': deepseek_model_profile,
             'meta-llama': meta_model_profile,
+            'moonshotai': moonshotai_model_profile,
         }
 
         profile = None
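With the `moonshotai` prefix registered, OpenRouter model names under that organization now pick up the MoonshotAI profile. A sketch (the API key and model name are illustrative):

    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.providers.openrouter import OpenRouterProvider

    model = OpenAIModel('moonshotai/kimi-k2', provider=OpenRouterProvider(api_key='...'))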
pydantic_ai/result.py
CHANGED
@@ -196,7 +196,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 and isinstance(event.part, _messages.TextPart)
                 and event.part.content
             ):
-                yield event.part.content, event.index
+                yield event.part.content, event.index  # pragma: no cover
             elif (  # pragma: no branch
                 isinstance(event, _messages.PartDeltaEvent)
                 and isinstance(event.delta, _messages.TextPartDelta)