pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +770 -0
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_result.py +3 -7
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +192 -560
- pydantic_ai/messages.py +21 -46
- pydantic_ai/models/__init__.py +104 -57
- pydantic_ai/models/anthropic.py +17 -10
- pydantic_ai/models/cohere.py +37 -25
- pydantic_ai/models/gemini.py +27 -7
- pydantic_ai/models/groq.py +19 -17
- pydantic_ai/models/mistral.py +22 -23
- pydantic_ai/models/openai.py +25 -12
- pydantic_ai/models/test.py +37 -22
- pydantic_ai/result.py +1 -1
- pydantic_ai/settings.py +46 -1
- pydantic_ai/tools.py +11 -8
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.22.dist-info}/METADATA +2 -3
- pydantic_ai_slim-0.0.22.dist-info/RECORD +30 -0
- pydantic_ai/models/ollama.py +0 -123
- pydantic_ai_slim-0.0.20.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.22.dist-info}/WHEEL +0 -0
pydantic_ai/models/groq.py
CHANGED

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, overload
+from typing import Literal, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -47,10 +47,7 @@ except ImportError as _import_error:

 GroqModelName = Literal[
     'llama-3.3-70b-versatile',
-    'llama-3.
-    'llama3-groq-70b-8192-tool-use-preview',
-    'llama3-groq-8b-8192-tool-use-preview',
-    'llama-3.1-70b-specdec',
+    'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
     'llama-3.2-1b-preview',
     'llama-3.2-3b-preview',
@@ -60,7 +57,6 @@ GroqModelName = Literal[
     'llama3-8b-8192',
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
-    'gemma-7b-it',
 ]
 """Named Groq models.

@@ -68,6 +64,12 @@ See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """


+class GroqModelSettings(ModelSettings):
+    """Settings used for a Groq model request."""
+
+    # This class is a placeholder for any future groq-specific settings
+
+
 @dataclass(init=False)
 class GroqModel(Model):
     """A model that uses the Groq API.
@@ -155,31 +157,31 @@ class GroqAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -191,13 +193,11 @@ class GroqAgentModel(AgentModel):

         groq_messages = list(chain(*(self._map_message(m) for m in messages)))

-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls',
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -205,6 +205,10 @@ class GroqAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -216,9 +220,7 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
-                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
-                )
+                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
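
Taken together, the groq.py changes type Groq requests against a provider-specific GroqModelSettings (a placeholder subclass of ModelSettings) and forward the new seed, presence_penalty, frequency_penalty and logit_bias settings to chat.completions.create. Below is a rough sketch of how these options would be supplied from user code; the 'groq:...' model string and the model_settings keyword reflect the agent API as I understand it in this release, so treat the exact call shape as an assumption rather than documentation:

from pydantic_ai import Agent

# Hypothetical usage sketch: the keys below come from ModelSettings in
# pydantic_ai/settings.py and are mapped onto the Groq request shown above.
agent = Agent('groq:llama-3.3-70b-versatile')

result = agent.run_sync(
    'Give me one word for "happy".',
    model_settings={
        'seed': 42,                # new in 0.0.22
        'presence_penalty': 0.5,   # new in 0.0.22
        'frequency_penalty': 0.5,  # new in 0.0.22
        'temperature': 0.0,
    },
)
print(result.data)
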
pydantic_ai/models/mistral.py
CHANGED

@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Callable, Literal, Union
+from typing import Any, Callable, Literal, Union, cast

 import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
@@ -15,7 +15,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -85,6 +84,12 @@ Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_o
 """


+class MistralModelSettings(ModelSettings):
+    """Settings used for a Mistral model request."""
+
+    # This class is a placeholder for any future mistral-specific settings
+
+
 @dataclass(init=False)
 class MistralModel(Model):
     """A model that uses Mistral.
@@ -159,7 +164,7 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, model_settings)
+        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
@@ -167,15 +172,14 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, model_settings)
+        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(self.result_tools, response)

     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], model_settings: MistralModelSettings
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        model_settings = model_settings or {}
         response = await self.client.chat.complete_async(
             model=str(self.model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),
@@ -187,6 +191,7 @@ class MistralAgentModel(AgentModel):
             temperature=model_settings.get('temperature', UNSET),
             top_p=model_settings.get('top_p', 1),
             timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+            random_seed=model_settings.get('seed', UNSET),
         )
         assert response, 'A unexpected empty response from Mistral.'
         return response
@@ -194,12 +199,11 @@ class MistralAgentModel(AgentModel):
     async def _stream_completions_create(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: MistralModelSettings,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}

         if self.result_tools and self.function_tools or self.function_tools:
             # Function Calling
@@ -213,6 +217,8 @@ class MistralAgentModel(AgentModel):
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                presence_penalty=model_settings.get('presence_penalty'),
+                frequency_penalty=model_settings.get('frequency_penalty'),
             )

         elif self.result_tools:
@@ -317,18 +323,11 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
-
-
-
-
-
-            )
-        else:
-            return MistralToolCall(
-                id=t.tool_call_id,
-                type='function',
-                function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
-            )
+        return MistralToolCall(
+            id=t.tool_call_id,
+            type='function',
+            function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
+        )

     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""
@@ -511,7 +510,7 @@ class MistralStreamedResponse(StreamedResponse):
                 continue

            # The following part_id will be thrown away
-            return ToolCallPart
+            return ToolCallPart(tool_name=result_tool.name, args=output_json)

     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
@@ -569,7 +568,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa
     tool_call_id = tool_call.id or None
     func_call = tool_call.function

-    return ToolCallPart
+    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)


 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
@@ -600,7 +599,7 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
    elif isinstance(content, str):
        result = content

-    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and
+    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
    if result and len(result) == 0:
        result = None

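
The Mistral changes follow the same settings pattern, and they also show the new ToolCallPart construction used across all providers in this release: ToolCallPart.from_raw_args and the ArgsJson/ArgsDict wrappers are gone, and args is passed directly as either a dict or a JSON string. A small illustration built only from the constructor calls visible in the diffs above (the tool name and arguments are made up):

from pydantic_ai.messages import ToolCallPart

# args can be a mapping (as the Mistral SDK returns) or a raw JSON string
# (as the OpenAI-compatible APIs return); both forms appear in the mappers above.
from_dict = ToolCallPart(tool_name='get_weather', args={'city': 'Paris'}, tool_call_id='call_1')
from_json = ToolCallPart(tool_name='get_weather', args='{"city": "Paris"}', tool_call_id='call_2')
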
pydantic_ai/models/openai.py
CHANGED

@@ -1,11 +1,12 @@
 from __future__ import annotations as _annotations

+import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, Union, overload
+from typing import Literal, Union, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
@@ -48,12 +49,18 @@ except ImportError as _import_error:
 OpenAIModelName = Union[ChatModel, str]
 """
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
 """

 OpenAISystemPromptRole = Literal['system', 'developer', 'user']


+class OpenAIModelSettings(ModelSettings):
+    """Settings used for an OpenAI model request."""
+
+    # This class is a placeholder for any future openai-specific settings
+
+
 @dataclass(init=False)
 class OpenAIModel(Model):
     """A model that uses the OpenAI API.
@@ -95,7 +102,11 @@ class OpenAIModel(Model):
         In the future, this may be inferred from the model name.
         """
         self.model_name: OpenAIModelName = model_name
-        if openai_client is not None:
+        # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+        # openai compatible models do not always need an API key.
+        if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
+            api_key = ''
+        elif openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
@@ -153,31 +164,31 @@ class OpenAIAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -189,13 +200,11 @@ class OpenAIAgentModel(AgentModel):

         openai_messages = list(chain(*(self._map_message(m) for m in messages)))

-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls',
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -204,6 +213,10 @@ class OpenAIAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
@@ -215,7 +228,7 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart
+                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
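
The openai.py change worth calling out is the API-key workaround in OpenAIModel.__init__: when a custom base_url is given, no api_key is passed and OPENAI_API_KEY is not set, the model now falls back to an empty key instead of letting the OpenAI client raise. Together with the removal of pydantic_ai/models/ollama.py in the file list above, this points at using OpenAIModel directly against locally served OpenAI-compatible backends. A hedged sketch; the localhost URL is an assumption (a typical Ollama-style endpoint), not something the package prescribes:

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# No API key needed: with base_url set and OPENAI_API_KEY unset, 0.0.22
# substitutes an empty key rather than failing at construction time.
local_model = OpenAIModel('llama3.2', base_url='http://localhost:11434/v1')
agent = Agent(local_model)
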
pydantic_ai/models/test.py
CHANGED

@@ -12,7 +12,6 @@ import pydantic_core

 from .. import _utils
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -34,6 +33,20 @@ from . import (
 from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]


+@dataclass
+class _TextResult:
+    """A private wrapper class to tag a result that came from the custom_result_text field."""
+
+    value: str | None
+
+
+@dataclass
+class _FunctionToolResult:
+    """A wrapper class to tag a result that came from the custom_result_args field."""
+
+    value: Any | None
+
+
 @dataclass
 class TestModel(Model):
     """A model specifically for testing purposes.
@@ -53,7 +66,7 @@ class TestModel(Model):
     call_tools: list[str] | Literal['all'] = 'all'
     """List of tools to call. If `'all'`, all tools will be called."""
     custom_result_text: str | None = None
-    """If set, this text is
+    """If set, this text is returned as the final result."""
     custom_result_args: Any | None = None
     """If set, these args will be passed to the result tool."""
     seed: int = 0
@@ -95,21 +108,21 @@
         if self.custom_result_text is not None:
             assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
-            result:
+            result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
             assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
             result_tool = result_tools[0]

             if k := result_tool.outer_typed_dict_key:
-                result =
+                result = _FunctionToolResult({k: self.custom_result_args})
             else:
-                result =
+                result = _FunctionToolResult(self.custom_result_args)
         elif allow_text_result:
-            result =
+            result = _TextResult(None)
         elif result_tools:
-            result =
+            result = _FunctionToolResult(None)
         else:
-            result =
+            result = _TextResult(None)

         return TestAgentModel(tool_calls, result, result_tools, self.seed)

@@ -126,7 +139,7 @@ class TestAgentModel(AgentModel):

     tool_calls: list[tuple[str, ToolDefinition]]
     # left means the text is plain text; right means it's a function call
-    result:
+    result: _TextResult | _FunctionToolResult
     result_tools: list[ToolDefinition]
     seed: int
     model_name: str = 'test'
@@ -152,7 +165,7 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart
+                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
                 model_name=self.model_name,
             )

@@ -166,7 +179,7 @@ class TestAgentModel(AgentModel):
         # Handle retries for both function tools and result tools
         # Check function tools first
         retry_parts: list[ModelResponsePart] = [
-            ToolCallPart
+            ToolCallPart(name, self.gen_tool_args(args))
             for name, args in self.tool_calls
             if name in new_retry_names
         ]
@@ -174,15 +187,20 @@ class TestAgentModel(AgentModel):
         if self.result_tools:
             retry_parts.extend(
                 [
-                    ToolCallPart
+                    ToolCallPart(
+                        tool.name,
+                        self.result.value
+                        if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
+                        else self.gen_tool_args(tool),
+                    )
                     for tool in self.result_tools
                     if tool.name in new_retry_names
                 ]
             )
             return ModelResponse(parts=retry_parts, model_name=self.model_name)

-        if
-        if response_text.value is None:
+        if isinstance(self.result, _TextResult):
+            if (response_text := self.result.value) is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
@@ -197,20 +215,18 @@ class TestAgentModel(AgentModel):
                 else:
                     return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
             else:
-                return ModelResponse(parts=[TextPart(response_text
+                return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
         else:
             assert self.result_tools, 'No result tools provided'
-            custom_result_args = self.result.
+            custom_result_args = self.result.value
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart
+                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
                 )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(
-                    parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)], model_name=self.model_name
-                )
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)


 @dataclass
@@ -241,9 +257,8 @@ class TestStreamedResponse(StreamedResponse):
                 self._usage += _get_string_usage(word)
                 yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
             else:
-                args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
                 yield self._parts_manager.handle_tool_call_part(
-                    vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
+                    vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )

     def timestamp(self) -> datetime:
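
In test.py the old Either-style tagging of the canned result is replaced by two small dataclasses, _TextResult and _FunctionToolResult, which makes the custom_result_text / custom_result_args paths explicit. From the outside the behaviour is exercised the same way as before; a sketch using the override pattern I believe this release supports (the agent model string is arbitrary):

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent('openai:gpt-4o')

# custom_result_text is wrapped in _TextResult internally and returned verbatim.
with agent.override(model=TestModel(custom_result_text='canned answer')):
    result = agent.run_sync('what is 2 + 2?')

print(result.data)  # 'canned answer'
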
pydantic_ai/result.py
CHANGED

@@ -46,7 +46,7 @@ A function that always takes and returns the same type of data (which is the res
 * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
 * may or may not be async

-Usage `ResultValidatorFunc[
+Usage `ResultValidatorFunc[AgentDepsT, T]`.
 """

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
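
The result.py change only completes the truncated docstring: a result validator is parameterised as ResultValidatorFunc[AgentDepsT, T]. For concreteness, a sketch of a validator with int deps and a str result; the decorator and ModelRetry names are the ones I believe this release exposes, so verify against the package docs:

from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('openai:gpt-4o', deps_type=int, result_type=str)


@agent.result_validator
def must_mention_deps(ctx: RunContext[int], result: str) -> str:
    # Corresponds to ResultValidatorFunc[int, str].
    if str(ctx.deps) not in result:
        raise ModelRetry(f'answer must mention {ctx.deps}')
    return result
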
pydantic_ai/settings.py
CHANGED

@@ -80,11 +80,56 @@ class ModelSettings(TypedDict, total=False):
     """Whether to allow parallel tool calls.

     Supported by:
-
+
+    * OpenAI (some models, not o1)
     * Groq
     * Anthropic
     """

+    seed: int
+    """The random seed to use for the model, theoretically allowing for deterministic results.
+
+    Supported by:
+
+    * OpenAI
+    * Groq
+    * Cohere
+    * Mistral
+    """
+
+    presence_penalty: float
+    """Penalize new tokens based on whether they have appeared in the text so far.
+
+    Supported by:
+
+    * OpenAI
+    * Groq
+    * Cohere
+    * Gemini
+    * Mistral
+    """
+
+    frequency_penalty: float
+    """Penalize new tokens based on their existing frequency in the text so far.
+
+    Supported by:
+
+    * OpenAI
+    * Groq
+    * Cohere
+    * Gemini
+    * Mistral
+    """
+
+    logit_bias: dict[str, int]
+    """Modify the likelihood of specified tokens appearing in the completion.
+
+    Supported by:
+
+    * OpenAI
+    * Groq
+    """
+

 def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
     """Merge two sets of model settings, preferring the overrides.
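
Since ModelSettings is a TypedDict with total=False, the four new fields are just optional keys; which providers honour each key is listed in the docstrings above (logit_bias, for instance, is only forwarded by the OpenAI and Groq models). A minimal sketch of building a settings dict with the new fields; the logit_bias token id is an arbitrary OpenAI-style example:

from pydantic_ai.settings import ModelSettings

settings: ModelSettings = {
    'seed': 1234,
    'presence_penalty': 0.3,
    'frequency_penalty': 0.3,
    'logit_bias': {'50256': -100},  # token-id string -> bias
}
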
pydantic_ai/tools.py
CHANGED

@@ -79,13 +79,13 @@ SystemPromptFunc = Union[
 ]
 """A function that may or maybe not take `RunContext` as an argument, and may or may not be async.

-Usage `SystemPromptFunc[
+Usage `SystemPromptFunc[AgentDepsT]`.
 """

 ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
 """A tool function that takes `RunContext` as the first argument.

-Usage `ToolContextFunc[
+Usage `ToolContextFunc[AgentDepsT, ToolParams]`.
 """
 ToolFuncPlain = Callable[ToolParams, Any]
 """A tool function that does not take `RunContext` as the first argument.
@@ -98,7 +98,7 @@ ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[To
 This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
 [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].

-Usage `ToolFuncEither[
+Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
 """
 ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
 """Definition of a function that can prepare a tool definition at call time.
@@ -125,7 +125,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
 hitchhiker = Tool(hitchhiker, prepare=only_if_42)
 ```

-Usage `ToolPrepareFunc[
+Usage `ToolPrepareFunc[AgentDepsT]`.
 """

 DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
@@ -158,6 +158,9 @@ class Tool(Generic[AgentDepsT]):
     _var_positional_field: str | None = field(init=False)
     _validator: SchemaValidator = field(init=False, repr=False)
     _parameters_json_schema: ObjectJsonSchema = field(init=False)
+
+    # TODO: Move this state off the Tool class, which is otherwise stateless.
+    # This should be tracked inside a specific agent run, not the tool.
     current_retry: int = field(default=0, init=False)

     def __init__(
@@ -261,13 +264,13 @@

     async def run(
         self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
-    ) -> _messages.
+    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
         """Run the tool function asynchronously."""
         try:
-            if isinstance(message.args,
-                args_dict = self._validator.validate_json(message.args
+            if isinstance(message.args, str):
+                args_dict = self._validator.validate_json(message.args)
             else:
-                args_dict = self._validator.validate_python(message.args
+                args_dict = self._validator.validate_python(message.args)
         except ValidationError as e:
             return self._on_error(e, message)

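
The Tool.run change mirrors the new ToolCallPart.args shape: a JSON string goes through validate_json, anything else (a mapping) through validate_python. A rough standalone illustration of that dispatch, using json/dict stand-ins rather than the library's SchemaValidator:

import json
from typing import Any

from pydantic_ai.messages import ToolCallPart


def parse_args(part: ToolCallPart) -> dict[str, Any]:
    # str -> validate_json path; mapping -> validate_python path in Tool.run.
    if isinstance(part.args, str):
        return json.loads(part.args)
    return dict(part.args)


print(parse_args(ToolCallPart(tool_name='echo', args='{"x": 1}')))
print(parse_args(ToolCallPart(tool_name='echo', args={'x': 1})))
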
{pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.22.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.20
+Version: 0.0.22
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,13 +28,12 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
+Requires-Dist: pydantic-graph==0.0.22
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
 Provides-Extra: cohere
 Requires-Dist: cohere>=5.13.11; extra == 'cohere'
-Provides-Extra: graph
-Requires-Dist: pydantic-graph==0.0.20; extra == 'graph'
 Provides-Extra: groq
 Requires-Dist: groq>=0.12.0; extra == 'groq'
 Provides-Extra: logfire