pydantic-ai-slim 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +60 -57
- pydantic_ai/_cli.py +18 -3
- pydantic_ai/_parts_manager.py +5 -4
- pydantic_ai/_run_context.py +2 -2
- pydantic_ai/_tool_manager.py +50 -29
- pydantic_ai/ag_ui.py +4 -4
- pydantic_ai/agent/__init__.py +69 -84
- pydantic_ai/agent/abstract.py +16 -18
- pydantic_ai/agent/wrapper.py +4 -6
- pydantic_ai/direct.py +4 -4
- pydantic_ai/durable_exec/temporal/_agent.py +13 -15
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/messages.py +16 -6
- pydantic_ai/models/__init__.py +5 -5
- pydantic_ai/models/anthropic.py +47 -46
- pydantic_ai/models/bedrock.py +25 -27
- pydantic_ai/models/cohere.py +20 -25
- pydantic_ai/models/fallback.py +15 -15
- pydantic_ai/models/function.py +7 -9
- pydantic_ai/models/gemini.py +43 -39
- pydantic_ai/models/google.py +59 -40
- pydantic_ai/models/groq.py +23 -19
- pydantic_ai/models/huggingface.py +27 -23
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mcp_sampling.py +1 -2
- pydantic_ai/models/mistral.py +24 -22
- pydantic_ai/models/openai.py +101 -45
- pydantic_ai/models/test.py +4 -5
- pydantic_ai/profiles/__init__.py +10 -1
- pydantic_ai/profiles/deepseek.py +1 -1
- pydantic_ai/profiles/moonshotai.py +1 -1
- pydantic_ai/profiles/openai.py +13 -3
- pydantic_ai/profiles/qwen.py +4 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/huggingface.py +27 -0
- pydantic_ai/providers/ollama.py +105 -0
- pydantic_ai/providers/openai.py +1 -1
- pydantic_ai/providers/openrouter.py +2 -0
- pydantic_ai/result.py +6 -6
- pydantic_ai/run.py +4 -11
- pydantic_ai/tools.py +9 -9
- pydantic_ai/usage.py +229 -67
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +47 -46
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/mistral.py
CHANGED
@@ -40,7 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import
+from ..usage import RequestUsage
 from . import (
     Model,
     ModelRequestParameters,
@@ -120,7 +120,7 @@ class MistralModel(Model):
     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

     _model_name: MistralModelName = field(repr=False)
-
+    _provider: Provider[Mistral] = field(repr=False)

     def __init__(
         self,
@@ -147,13 +147,24 @@ class MistralModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
-        return self.
+        return self._provider.base_url
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -167,7 +178,6 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -186,16 +196,6 @@ class MistralModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> MistralModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -348,7 +348,11 @@ class MistralModel(Model):
             parts.append(tool)

         return ModelResponse(
-            parts,
+            parts,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -699,17 +703,15 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) ->
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return
-
-
-            total_tokens=response.usage.total_tokens,
-            details=None,
+        return RequestUsage(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
         )
     else:
-        return
+        return RequestUsage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
pydantic_ai/models/openai.py
CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Any, Literal, Union, cast, overload

 from pydantic import ValidationError
-from typing_extensions import assert_never
+from typing_extensions import assert_never, deprecated

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
@@ -40,7 +40,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
-from ..profiles.openai import OpenAIModelProfile
+from ..profiles.openai import OpenAIModelProfile, OpenAISystemPromptRole
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -100,8 +100,6 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """

-OpenAISystemPromptRole = Literal['system', 'developer', 'user']
-

 class OpenAIModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""
@@ -196,10 +194,59 @@ class OpenAIModel(Model):
     """

     client: AsyncOpenAI = field(repr=False)
-    system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False)

     _model_name: OpenAIModelName = field(repr=False)
-
+    _provider: Provider[AsyncOpenAI] = field(repr=False)
+
+    @overload
+    def __init__(
+        self,
+        model_name: OpenAIModelName,
+        *,
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'moonshotai',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
+            'ollama',
+        ]
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
+    ) -> None: ...
+
+    @deprecated('Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
+    @overload
+    def __init__(
+        self,
+        model_name: OpenAIModelName,
+        *,
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'moonshotai',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
+            'ollama',
+        ]
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
+        settings: ModelSettings | None = None,
+    ) -> None: ...

     def __init__(
         self,
@@ -217,6 +264,7 @@ class OpenAIModel(Model):
             'together',
             'heroku',
             'github',
+            'ollama',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -239,16 +287,33 @@ class OpenAIModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

-        self.system_prompt_role = system_prompt_role
-
         super().__init__(settings=settings, profile=profile or provider.model_profile)

+        if system_prompt_role is not None:
+            self.profile = OpenAIModelProfile(openai_system_prompt_role=system_prompt_role).update(self.profile)
+
     @property
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
+    @property
+    @deprecated('Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
+    def system_prompt_role(self) -> OpenAISystemPromptRole | None:
+        return OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -260,7 +325,6 @@ class OpenAIModel(Model):
             messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -278,16 +342,6 @@ class OpenAIModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> OpenAIModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -444,8 +498,8 @@ class OpenAIModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-
-
+            provider_details=vendor_details,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -561,9 +615,10 @@ class OpenAIModel(Model):
     async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-
+                system_prompt_role = OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
+                if system_prompt_role == 'developer':
                     yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
-                elif
+                elif system_prompt_role == 'user':
                     yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
                 else:
                     yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
@@ -659,10 +714,9 @@ class OpenAIResponsesModel(Model):
     """

     client: AsyncOpenAI = field(repr=False)
-    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     _model_name: OpenAIModelName = field(repr=False)
-
+    _provider: Provider[AsyncOpenAI] = field(repr=False)

     def __init__(
         self,
@@ -685,6 +739,7 @@ class OpenAIResponsesModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -696,8 +751,8 @@ class OpenAIResponsesModel(Model):

     @property
     def system(self) -> str:
-        """The
-        return self.
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -746,7 +801,7 @@ class OpenAIResponsesModel(Model):
             items,
             usage=_map_usage(response),
             model_name=response.model,
-
+            provider_request_id=response.id,
             timestamp=timestamp,
         )

@@ -1094,11 +1149,12 @@ class OpenAIStreamedResponse(StreamedResponse):

         # Handle the text part of the response
         content = choice.delta.content
-        if content:
+        if content is not None:
             maybe_event = self._parts_manager.handle_text_delta(
                 vendor_part_id='content',
                 content=content,
                 thinking_tags=self._model_profile.thinking_tags,
+                ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
             )
             if maybe_event is not None:  # pragma: no branch
                 yield maybe_event
@@ -1263,10 +1319,10 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.
+        return usage.RequestUsage()
     elif isinstance(response_usage, responses.ResponseUsage):
         details: dict[str, int] = {
             key: value
@@ -1276,29 +1332,29 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response
             if isinstance(value, int)
         }
         details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens
-
-
-
-
-            total_tokens=response_usage.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response_usage.input_tokens,
+            output_tokens=response_usage.output_tokens,
+            cache_read_tokens=response_usage.input_tokens_details.cached_tokens,
             details=details,
         )
     else:
         details = {
             key: value
             for key, value in response_usage.model_dump(
-                exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'}
+                exclude_none=True, exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'}
            ).items()
             if isinstance(value, int)
         }
+        u = usage.RequestUsage(
+            input_tokens=response_usage.prompt_tokens,
+            output_tokens=response_usage.completion_tokens,
+            details=details,
+        )
         if response_usage.completion_tokens_details is not None:
             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
+            u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0
         if response_usage.prompt_tokens_details is not None:
-
-
-
-            request_tokens=response_usage.prompt_tokens,
-            response_tokens=response_usage.completion_tokens,
-            total_tokens=response_usage.total_tokens,
-            details=details,
-        )
+            u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
+            u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
+        return u
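Taken together, the OpenAI changes move `system_prompt_role` off the model class and onto the profile: the constructor argument and property are kept but deprecated, and `_map_user_message` now reads `openai_system_prompt_role` from the profile. A hedged migration sketch (the model name and role are illustrative, and constructing `OpenAIModel` still assumes OpenAI credentials are configured):

    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.profiles.openai import OpenAIModelProfile

    # Deprecated: OpenAIModel('gpt-4o', system_prompt_role='developer')
    # Preferred after this change: carry the role on the profile.
    model = OpenAIModel(
        'gpt-4o',
        profile=OpenAIModelProfile(openai_system_prompt_role='developer'),
    )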
pydantic_ai/models/test.py
CHANGED
@@ -31,7 +31,7 @@ from ..messages import (
 from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import
+from ..usage import RequestUsage
 from . import Model, ModelRequestParameters, StreamedResponse
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]

@@ -113,7 +113,6 @@ class TestModel(Model):
         self.last_model_request_parameters = model_request_parameters
         model_response = self._request(messages, model_settings, model_request_parameters)
         model_response.usage = _estimate_usage([*messages, model_response])
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -141,7 +140,7 @@ class TestModel(Model):

     @property
     def system(self) -> str:
-        """The
+        """The model provider."""
         return self._system

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
@@ -468,6 +467,6 @@ class _JsonSchemaTestData:
         return s


-def _get_string_usage(text: str) ->
+def _get_string_usage(text: str) -> RequestUsage:
     response_tokens = _estimate_string_tokens(text)
-    return
+    return RequestUsage(output_tokens=response_tokens)
pydantic_ai/profiles/__init__.py
CHANGED
@@ -20,7 +20,7 @@ __all__ = [

 @dataclass
 class ModelProfile:
-    """Describes how requests to
+    """Describes how requests to and responses from specific models or families of models need to be constructed and processed to get the best results, independent of the model and provider classes used."""

     supports_tools: bool = True
     """Whether the model supports tools."""
@@ -46,6 +46,15 @@ class ModelProfile:
     thinking_tags: tuple[str, str] = ('<think>', '</think>')
     """The tags used to indicate thinking parts in the model's output. Defaults to ('<think>', '</think>')."""

+    ignore_streamed_leading_whitespace: bool = False
+    """Whether to ignore leading whitespace when streaming a response.
+
+    This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
+    which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
+
+    This is currently only used by `OpenAIModel`, `HuggingFaceModel`, and `GroqModel`.
+    """
+
     @classmethod
     def from_profile(cls, profile: ModelProfile | None) -> Self:
         """Build a ModelProfile subclass instance from a ModelProfile instance."""
pydantic_ai/profiles/deepseek.py
CHANGED
pydantic_ai/profiles/openai.py
CHANGED
@@ -2,11 +2,13 @@ from __future__ import annotations as _annotations

 import re
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Literal

 from . import ModelProfile
 from ._json_schema import JsonSchema, JsonSchemaTransformer

+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+

 @dataclass
 class OpenAIModelProfile(ModelProfile):
@@ -26,8 +28,10 @@ class OpenAIModelProfile(ModelProfile):
     # safe to pass that value along. Default is `True` to preserve existing
     # behaviour for OpenAI itself and most providers.
     openai_supports_tool_choice_required: bool = True
-    """Whether the provider accepts the value ``tool_choice='required'`` in the
-
+    """Whether the provider accepts the value ``tool_choice='required'`` in the request payload."""
+
+    openai_system_prompt_role: OpenAISystemPromptRole | None = None
+    """The role to use for the system prompt message. If not provided, defaults to `'system'`."""


 def openai_model_profile(model_name: str) -> ModelProfile:
@@ -36,11 +40,17 @@ def openai_model_profile(model_name: str) -> ModelProfile:
     # Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later.
     # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
     # when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
+
+    # The o1-mini model doesn't support the `system` role, so we default to `user`.
+    # See https://github.com/pydantic/pydantic-ai/issues/974 for more details.
+    openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None
+
     return OpenAIModelProfile(
         json_schema_transformer=OpenAIJsonSchemaTransformer,
         supports_json_schema_output=True,
         supports_json_object_output=True,
         openai_supports_sampling_settings=not is_reasoning_model,
+        openai_system_prompt_role=openai_system_prompt_role,
     )

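One consequence of the `openai_model_profile` change: profiles for o1-mini models now carry `openai_system_prompt_role='user'` by default, without the caller passing the deprecated constructor argument. A quick check using only names introduced in this diff:

    from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile

    profile = openai_model_profile('o1-mini')
    assert isinstance(profile, OpenAIModelProfile)
    print(profile.openai_system_prompt_role)  # 'user'
    print(openai_model_profile('gpt-4o').openai_system_prompt_role)  # None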
pydantic_ai/profiles/qwen.py
CHANGED
@@ -5,4 +5,7 @@ from . import InlineDefsJsonSchemaTransformer, ModelProfile

 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
-    return ModelProfile(
+    return ModelProfile(
+        json_schema_transformer=InlineDefsJsonSchemaTransformer,
+        ignore_streamed_leading_whitespace=True,
+    )
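Qwen profiles now opt in to the leading-whitespace workaround alongside the inline-defs schema transform. A sketch of what the updated factory returns (both fields come straight from the hunk above; the model name is illustrative):

    from pydantic_ai.profiles import InlineDefsJsonSchemaTransformer
    from pydantic_ai.profiles.qwen import qwen_model_profile

    profile = qwen_model_profile('qwen3:8b')
    assert profile is not None
    print(profile.json_schema_transformer is InlineDefsJsonSchemaTransformer)  # True
    print(profile.ignore_streamed_leading_whitespace)  # True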
pydantic_ai/providers/__init__.py
CHANGED

@@ -123,6 +123,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .huggingface import HuggingFaceProvider

         return HuggingFaceProvider
+    elif provider == 'ollama':
+        from .ollama import OllamaProvider
+
+        return OllamaProvider
     elif provider == 'github':
         from .github import GitHubProvider

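With this branch in place, the provider string 'ollama' resolves to the new provider class shown further below. A quick sketch (assuming `infer_provider_class` remains importable from `pydantic_ai.providers`, where this function lives; instantiating the provider still needs `OLLAMA_BASE_URL`):

    from pydantic_ai.providers import infer_provider_class

    print(infer_provider_class('ollama').__name__)  # OllamaProvider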
pydantic_ai/providers/huggingface.py
CHANGED

@@ -6,6 +6,13 @@ from typing import overload
 from httpx import AsyncClient

 from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile

 try:
     from huggingface_hub import AsyncInferenceClient
@@ -33,6 +40,26 @@ class HuggingFaceProvider(Provider[AsyncInferenceClient]):
     def client(self) -> AsyncInferenceClient:
         return self._client

+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'deepseek-ai': deepseek_model_profile,
+            'google': google_model_profile,
+            'qwen': qwen_model_profile,
+            'meta-llama': meta_model_profile,
+            'mistralai': mistral_model_profile,
+            'moonshotai': moonshotai_model_profile,
+        }
+
+        if '/' not in model_name:
+            return None
+
+        model_name = model_name.lower()
+        provider, model_name = model_name.split('/', 1)
+        if provider in provider_to_profile:
+            return provider_to_profile[provider](model_name)
+
+        return None
+
     @overload
     def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
     @overload
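The HuggingFace provider now infers a profile from the organisation prefix of an `org/model` ID and returns `None` for names without a slash. A standalone restatement of that dispatch using the same profile functions the diff imports (the model IDs are illustrative; this is not the provider's own code path):

    from pydantic_ai.profiles import ModelProfile
    from pydantic_ai.profiles.deepseek import deepseek_model_profile
    from pydantic_ai.profiles.qwen import qwen_model_profile

    def profile_for(model_id: str) -> ModelProfile | None:
        # Mirrors HuggingFaceProvider.model_profile for two of the mapped orgs.
        table = {'deepseek-ai': deepseek_model_profile, 'qwen': qwen_model_profile}
        if '/' not in model_id:
            return None
        org, name = model_id.lower().split('/', 1)
        fn = table.get(org)
        return fn(name) if fn else None

    print(profile_for('Qwen/Qwen3-8B'))  # a qwen ModelProfile
    print(profile_for('gpt2'))           # None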
pydantic_ai/providers/ollama.py
ADDED

@@ -0,0 +1,105 @@
+from __future__ import annotations as _annotations
+
+import os
+
+import httpx
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Ollama provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class OllamaProvider(Provider[AsyncOpenAI]):
+    """Provider for local or remote Ollama API."""
+
+    @property
+    def name(self) -> str:
+        return 'ollama'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {
+            'llama': meta_model_profile,
+            'gemma': google_model_profile,
+            'qwen': qwen_model_profile,
+            'qwq': qwen_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+            'command': cohere_model_profile,
+        }
+
+        profile = None
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                profile = profile_func(model_name)
+
+        # As OllamaProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # we need to maintain that behavior unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        """Create a new Ollama provider.
+
+        Args:
+            base_url: The base url for the Ollama requests. If not provided, the `OLLAMA_BASE_URL` environment variable
+                will be used if available.
+            api_key: The API key to use for authentication, if not provided, the `OLLAMA_API_KEY` environment variable
+                will be used if available.
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
+            self._client = openai_client
+        else:
+            base_url = base_url or os.getenv('OLLAMA_BASE_URL')
+            if not base_url:
+                raise UserError(
+                    'Set the `OLLAMA_BASE_URL` environment variable or pass it via `OllamaProvider(base_url=...)`'
+                    'to use the Ollama provider.'
+                )
+
+            # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+            # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
+            api_key = api_key or os.getenv('OLLAMA_API_KEY') or 'api-key-not-set'
+
+            if http_client is not None:
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
+            else:
+                http_client = cached_async_http_client(provider='ollama')
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)