pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two published versions.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
pydantic_ai/models/test.py
CHANGED
@@ -2,27 +2,29 @@ from __future__ import annotations as _annotations
 
 import re
 import string
-from collections.abc import AsyncIterator, Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import date, datetime, timedelta
 from typing import Any, Literal
 
 import pydantic_core
+from typing_extensions import assert_never
 
 from .. import _utils
 from ..messages import (
-
-
-
-
-
-
-
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    RetryPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -55,25 +57,38 @@ class TestModel(Model):
     """If set, these args will be passed to the result tool."""
     seed: int = 0
     """Seed for generating random data."""
-
-
+    agent_model_function_tools: list[ToolDefinition] | None = field(default=None, init=False)
+    """Definition of function tools passed to the model.
+
+    This is set when the model is called, so will reflect the function tools from the last step of the last run.
+    """
     agent_model_allow_text_result: bool | None = field(default=None, init=False)
-
+    """Whether plain text responses from the model are allowed.
+
+    This is set when the model is called, so will reflect the value from the last step of the last run.
+    """
+    agent_model_result_tools: list[ToolDefinition] | None = field(default=None, init=False)
+    """Definition of result tools passed to the model.
+
+    This is set when the model is called, so will reflect the result tools from the last step of the last run.
+    """
 
     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        self.
+        self.agent_model_function_tools = function_tools
         self.agent_model_allow_text_result = allow_text_result
-        self.agent_model_result_tools =
+        self.agent_model_result_tools = result_tools
 
         if self.call_tools == 'all':
-            tool_calls = [(r.name, r) for r in function_tools
+            tool_calls = [(r.name, r) for r in function_tools]
         else:
-
+            function_tools_lookup = {t.name: t for t in function_tools}
+            tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
             tool_calls = [(r.name, r) for r in tools_to_call]
 
         if self.custom_result_text is not None:
@@ -90,11 +105,12 @@ class TestModel(Model):
             result = _utils.Either(right=self.custom_result_args)
         elif allow_text_result:
             result = _utils.Either(left=None)
-        elif result_tools
+        elif result_tools:
             result = _utils.Either(right=None)
         else:
             result = _utils.Either(left=None)
-
+
+        return TestAgentModel(tool_calls, result, result_tools, self.seed)
 
     def name(self) -> str:
         return 'test-model'
@@ -107,73 +123,89 @@ class TestAgentModel(AgentModel):
     # NOTE: Avoid test discovery by pytest.
     __test__ = False
 
-    tool_calls: list[tuple[str,
+    tool_calls: list[tuple[str, ToolDefinition]]
     # left means the text is plain text; right means it's a function call
     result: _utils.Either[str | None, Any | None]
-    result_tools: list[
+    result_tools: list[ToolDefinition]
     seed: int
-    step: int = 0
-    last_message_count: int = 0
 
-    async def request(
-
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Cost]:
+        return self._request(messages, model_settings), Cost()
 
     @asynccontextmanager
-    async def request_stream(
-
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        msg = self._request(messages, model_settings)
         cost = Cost()
-
-
+
+        # TODO: Rework this once we make StreamTextResponse more general
+        texts: list[str] = []
+        tool_calls: list[ToolCallPart] = []
+        for item in msg.parts:
+            if isinstance(item, TextPart):
+                texts.append(item.content)
+            elif isinstance(item, ToolCallPart):
+                tool_calls.append(item)
+            else:
+                assert_never(item)
+
+        if texts:
+            yield TestStreamTextResponse('\n\n'.join(texts), cost)
         else:
             yield TestStreamStructuredResponse(msg, cost)
 
-    def gen_tool_args(self, tool_def:
-        return _JsonSchemaTestData(tool_def.
-
-    def _request(self, messages: list[
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-
+    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
+        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
+
+    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
+        # if there are tools, the first thing we want to do is call all of them
+        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+            return ModelResponse(
+                parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+            )
+
+        if messages:
+            last_message = messages[-1]
+            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
+
+            # check if there are any retry prompts, if so retry them
+            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
+            if new_retry_names:
+                return ModelResponse(
+                    parts=[
+                        ToolCallPart.from_dict(name, self.gen_tool_args(args))
+                        for name, args in self.tool_calls
+                        if name in new_retry_names
+                    ]
+                )
+
+        if response_text := self.result.left:
+            if response_text.value is None:
+                # build up details of tool responses
+                output: dict[str, Any] = {}
+                for message in messages:
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            if isinstance(part, ToolReturnPart):
+                                output[part.tool_name] = part.content
+                if output:
+                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
                 else:
-                    return
+                    return ModelResponse.from_text('success (no tool calls)')
             else:
-
-
-
-
-
-
-
-
-
-
+                return ModelResponse.from_text(response_text.value)
+        else:
+            assert self.result_tools, 'No result tools provided'
+            custom_result_args = self.result.right
+            result_tool = self.result_tools[self.seed % len(self.result_tools)]
+            if custom_result_args is not None:
+                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
+            else:
+                response_args = self.gen_tool_args(result_tool)
+                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
 
 
 @dataclass
@@ -213,7 +245,7 @@ class TestStreamTextResponse(StreamTextResponse):
 class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""
 
-    _structured_response:
+    _structured_response: ModelResponse
     _cost: Cost
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -221,7 +253,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     async def __anext__(self) -> None:
         return _utils.sync_anext(self._iter)
 
-    def get(self, *, final: bool = False) ->
+    def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response
 
     def cost(self) -> Cost:
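Taken together, these changes replace the old `AbstractToolDefinition` protocol with the concrete `ToolDefinition` from `pydantic_ai/tools.py`, thread `ModelSettings` through `request`/`request_stream`, and record the tool definitions each `agent_model` call received. Below is a minimal sketch of exercising the reworked `TestModel`; the agent wiring (`Agent`, `tool_plain`, `run_sync`) is assumed from `pydantic_ai/agent.py` in this release, and the tool name and prompt are illustrative.

# A minimal sketch, assuming the Agent API of this release; the tool and
# prompt are illustrative, not part of the diff above.
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent()  # model is supplied per-run below

@agent.tool_plain
def get_weather(city: str) -> str:
    """A tool whose ToolDefinition the TestModel will see and call."""
    return f'sunny in {city}'

model = TestModel()  # call_tools='all' by default: every tool is called with schema-generated args
result = agent.run_sync('What is the weather?', model=model)
print(result.data)  # JSON summary of the tool returns

# The definitions passed to agent_model() are now recorded on the model:
assert model.agent_model_function_tools is not None
print([t.name for t in model.agent_model_function_tools])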
pydantic_ai/models/vertexai.py
CHANGED
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -10,7 +9,8 @@ from httpx import AsyncClient as AsyncHTTPClient
 
 from .._utils import run_in_executor
 from ..exceptions import UserError
-from
+from ..tools import ToolDefinition
+from . import Model, cached_async_http_client
 from .gemini import GeminiAgentModel, GeminiModelName
 
 try:
@@ -18,11 +18,11 @@ try:
     from google.auth.credentials import Credentials as BaseCredentials
     from google.auth.transport.requests import Request
     from google.oauth2.service_account import Credentials as ServiceAccountCredentials
-except ImportError as
+except ImportError as _import_error:
     raise ImportError(
         'Please install `google-auth` to use the VertexAI model, '
-        "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
-    ) from
+        "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
+    ) from _import_error
 
 VERTEX_AI_URL_TEMPLATE = (
     'https://{region}-aiplatform.googleapis.com/v1'
@@ -109,11 +109,12 @@ class VertexAIModel(Model):
 
     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
-        url, auth = await self.
+        url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -124,7 +125,11 @@ class VertexAIModel(Model):
             result_tools=result_tools,
         )
 
-    async def
+    async def ainit(self) -> tuple[str, BearerTokenAuth]:
+        """Initialize the model, setting the URL and auth.
+
+        This will raise an error if authentication fails.
+        """
         if self.url is not None and self.auth is not None:
             return self.url, self.auth
 
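Besides the keyword-only `agent_model` signature and the corrected `pydantic-ai-slim[vertexai]` install hint, the initialization helper is now public as `ainit`, so authentication can be triggered eagerly instead of on the first request. A minimal sketch follows; the constructor arguments (`model_name` plus `service_account_file`) are assumptions based on this module's fields, and the path is illustrative.

# A minimal sketch; constructor arguments are assumptions based on this
# module's fields, and the service-account path is illustrative.
import asyncio

from pydantic_ai.models.vertexai import VertexAIModel


async def main() -> None:
    model = VertexAIModel('gemini-1.5-flash', service_account_file='service_account.json')
    # Authenticate now rather than on the first request; raises if auth fails.
    url, auth = await model.ainit()
    print(url)


asyncio.run(main())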
pydantic_ai/result.py
CHANGED
@@ -1,14 +1,14 @@
 from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncIterator, Awaitable, Callable
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, TypeVar, cast
 
 import logfire_api
 
-from . import _result, _utils, exceptions, messages, models
+from . import _result, _utils, exceptions, messages as _messages, models
 from .tools import AgentDeps
 
 __all__ = (
@@ -71,19 +71,19 @@ class _BaseRunResult(ABC, Generic[ResultData]):
     You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
     """
 
-    _all_messages: list[
+    _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    def all_messages(self) -> list[
-        """Return the history of
+    def all_messages(self) -> list[_messages.ModelMessage]:
+        """Return the history of _messages."""
         # this is a method to be consistent with the other methods
         return self._all_messages
 
     def all_messages_json(self) -> bytes:
-        """Return all messages from [`all_messages`][
-        return
+        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
+        return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
 
-    def new_messages(self) -> list[
+    def new_messages(self) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.
 
         System prompts and any messages from older runs are excluded.
@@ -91,8 +91,8 @@ class _BaseRunResult(ABC, Generic[ResultData]):
         return self.all_messages()[self._new_message_index :]
 
     def new_messages_json(self) -> bytes:
-        """Return new messages from [`new_messages`][
-        return
+        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
+        return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
 
     @abstractmethod
     def cost(self) -> Cost:
@@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     _result_schema: _result.ResultSchema[ResultData] | None
     _deps: AgentDeps
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
-
+    _result_tool_name: str | None
+    _on_complete: Callable[[], Awaitable[None]]
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.
 
@@ -205,11 +206,11 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                 combined = await self._validate_text_result(''.join(chunks))
                 yield combined
                 lf_span.set_attribute('combined_text', combined)
-                self._marked_completed(
+                await self._marked_completed(_messages.ModelResponse.from_text(combined))
 
     async def stream_structured(
         self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[tuple[
+    ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
         """Stream the response as an async iterable of Structured LLM Messages.
 
         !!! note
@@ -230,17 +231,21 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
             else:
                 # we should already have a message at this point, yield that first if it has any content
                 msg = self._stream_response.get()
-
-
+                for item in msg.parts:
+                    if isinstance(item, _messages.ToolCallPart) and item.has_content():
+                        yield msg, False
+                        break
                 async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
                     async for _ in group_iter:
                         msg = self._stream_response.get()
-
-
+                        for item in msg.parts:
+                            if isinstance(item, _messages.ToolCallPart) and item.has_content():
+                                yield msg, False
+                                break
                     msg = self._stream_response.get(final=True)
                     yield msg, True
                     lf_span.set_attribute('structured_response', msg)
-                    self._marked_completed(
+                    await self._marked_completed(msg)
 
     async def get_data(self) -> ResultData:
         """Stream the whole response, validate and return it."""
@@ -249,12 +254,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         if isinstance(self._stream_response, models.StreamTextResponse):
             text = ''.join(self._stream_response.get(final=True))
             text = await self._validate_text_result(text)
-            self._marked_completed(text
+            await self._marked_completed(_messages.ModelResponse.from_text(text))
             return cast(ResultData, text)
         else:
-
-            self._marked_completed(
-            return await self.validate_structured_result(
+            message = self._stream_response.get(final=True)
+            await self._marked_completed(message)
+            return await self.validate_structured_result(message)
 
     @property
     def is_structured(self) -> bool:
@@ -274,11 +279,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         return self._stream_response.timestamp()
 
     async def validate_structured_result(
-        self, message:
+        self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> ResultData:
         """Validate a structured result message."""
         assert self._result_schema is not None, 'Expected _result_schema to not be None'
-
+        assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
+        match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
         if match is None:
             raise exceptions.UnexpectedModelBehavior(
                 f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
@@ -288,7 +294,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
 
         for validator in self._result_validators:
-            result_data = await validator.validate(result_data, self._deps, 0, call)
+            result_data = await validator.validate(result_data, self._deps, 0, call, self._all_messages)
         return result_data
 
     async def _validate_text_result(self, text: str) -> str:
@@ -298,19 +304,11 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                 self._deps,
                 0,
                 None,
+                self._all_messages,
             )
         return text
 
-    def _marked_completed(
-        self, *, text: str | None = None, structured_message: messages.ModelStructuredResponse | None = None
-    ) -> None:
+    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
         self.is_complete = True
-
-
-            self._all_messages.append(
-                messages.ModelTextResponse(content=text, timestamp=self._stream_response.timestamp())
-            )
-        else:
-            assert structured_message is not None, 'Either text or structured_message should provided, not both'
-            self._all_messages.append(structured_message)
-        self._on_complete(self._all_messages)
+        self._all_messages.append(message)
+        await self._on_complete()
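The net effect on `StreamedRunResult`: completion handling collapses into a single awaited `_marked_completed(message)` that appends one `ModelResponse` and awaits a zero-argument `_on_complete` callback, and `stream_structured` now yields `(ModelResponse, bool)` pairs, skipping responses whose tool-call parts have no content yet. A minimal sketch of consuming a streamed run, assuming `Agent.run_stream` and `stream_text` from this release; the 'test' model shortcut and prompt are illustrative.

# A minimal sketch, assuming Agent.run_stream and stream_text from this
# release; the 'test' model shortcut and prompt are illustrative.
import asyncio

from pydantic_ai import Agent


async def main() -> None:
    agent = Agent('test')
    async with agent.run_stream('Tell me a joke.') as result:
        async for text in result.stream_text():
            print(text)
    # _marked_completed has run: the final ModelResponse is in the history.
    print(result.all_messages_json().decode())


asyncio.run(main())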
pydantic_ai/settings.py
ADDED
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from httpx import Timeout
+from typing_extensions import TypedDict
+
+
+class ModelSettings(TypedDict, total=False):
+    """Settings to configure an LLM.
+
+    Here we include only settings which apply to multiple models / model providers.
+    """
+
+    max_tokens: int
+    """The maximum number of tokens to generate before stopping.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    temperature: float
+    """Amount of randomness injected into the response.
+
+    Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's
+    maximum `temperature` for creative and generative tasks.
+
+    Note that even with `temperature` of `0.0`, the results will not be fully deterministic.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    top_p: float
+    """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
+
+    So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+
+    You should either alter `temperature` or `top_p`, but not both.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    timeout: float | Timeout
+    """Override the client-level default timeout for a request, in seconds.
+
+    Supported by:
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+
+def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
+    """Merge two sets of model settings, preferring the overrides.
+
+    A common use case is: merge_model_settings(<agent settings>, <run settings>)
+    """
+    # Note: we may want merge recursively if/when we add non-primitive values
+    if base and overrides:
+        return base | overrides
+    else:
+        return base or overrides
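Because `ModelSettings` is a `total=False` `TypedDict`, merging is a plain dict union in which override keys win; a small sketch of the behavior:

# A small sketch of merge_model_settings semantics (values are illustrative).
from pydantic_ai.settings import ModelSettings, merge_model_settings

agent_defaults: ModelSettings = {'max_tokens': 512, 'temperature': 0.2}
run_overrides: ModelSettings = {'temperature': 0.9}

print(merge_model_settings(agent_defaults, run_overrides))
#> {'max_tokens': 512, 'temperature': 0.9}
print(merge_model_settings(None, run_overrides))
#> {'temperature': 0.9}
print(merge_model_settings(agent_defaults, None))
#> {'max_tokens': 512, 'temperature': 0.2}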