pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +14 -3
- pydantic_ai/_result.py +6 -9
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +154 -90
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +34 -51
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +82 -35
- pydantic_ai/settings.py +69 -0
- pydantic_ai/tools.py +22 -28
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/METADATA +1 -2
- pydantic_ai_slim-0.0.15.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/WHEEL +0 -0
pydantic_ai/__init__.py
CHANGED
|
@@ -1,8 +1,19 @@
|
|
|
1
1
|
from importlib.metadata import version
|
|
2
2
|
|
|
3
|
-
from .agent import Agent
|
|
4
|
-
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
|
|
3
|
+
from .agent import Agent, capture_run_messages
|
|
4
|
+
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
|
|
5
5
|
from .tools import RunContext, Tool
|
|
6
6
|
|
|
7
|
-
__all__ =
|
|
7
|
+
__all__ = (
|
|
8
|
+
'Agent',
|
|
9
|
+
'capture_run_messages',
|
|
10
|
+
'RunContext',
|
|
11
|
+
'Tool',
|
|
12
|
+
'AgentRunError',
|
|
13
|
+
'ModelRetry',
|
|
14
|
+
'UnexpectedModelBehavior',
|
|
15
|
+
'UsageLimitExceeded',
|
|
16
|
+
'UserError',
|
|
17
|
+
'__version__',
|
|
18
|
+
)
|
|
8
19
|
__version__ = version('pydantic_ai_slim')
|
pydantic_ai/_result.py
CHANGED
|
@@ -12,8 +12,8 @@ from typing_extensions import Self, TypeAliasType, TypedDict
|
|
|
12
12
|
|
|
13
13
|
from . import _utils, messages as _messages
|
|
14
14
|
from .exceptions import ModelRetry
|
|
15
|
-
from .result import ResultData
|
|
16
|
-
from .tools import AgentDeps,
|
|
15
|
+
from .result import ResultData, ResultValidatorFunc
|
|
16
|
+
from .tools import AgentDeps, RunContext, ToolDefinition
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
@dataclass
|
|
@@ -29,25 +29,22 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
|
|
|
29
29
|
async def validate(
|
|
30
30
|
self,
|
|
31
31
|
result: ResultData,
|
|
32
|
-
deps: AgentDeps,
|
|
33
|
-
retry: int,
|
|
34
32
|
tool_call: _messages.ToolCallPart | None,
|
|
35
|
-
|
|
33
|
+
run_context: RunContext[AgentDeps],
|
|
36
34
|
) -> ResultData:
|
|
37
35
|
"""Validate a result but calling the function.
|
|
38
36
|
|
|
39
37
|
Args:
|
|
40
38
|
result: The result data after Pydantic validation of the message content.
|
|
41
|
-
deps: The agent dependencies.
|
|
42
|
-
retry: The current retry number.
|
|
43
39
|
tool_call: The original tool call message, `None` if there was no tool call.
|
|
44
|
-
|
|
40
|
+
run_context: The current run context.
|
|
45
41
|
|
|
46
42
|
Returns:
|
|
47
43
|
Result of either the validated result data (ok) or a retry message (Err).
|
|
48
44
|
"""
|
|
49
45
|
if self._takes_ctx:
|
|
50
|
-
|
|
46
|
+
ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
|
|
47
|
+
args = ctx, result
|
|
51
48
|
else:
|
|
52
49
|
args = (result,)
|
|
53
50
|
|
pydantic_ai/_system_prompt.py
CHANGED
|
@@ -19,9 +19,9 @@ class SystemPromptRunner(Generic[AgentDeps]):
|
|
|
19
19
|
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
|
|
20
20
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
21
21
|
|
|
22
|
-
async def run(self,
|
|
22
|
+
async def run(self, run_context: RunContext[AgentDeps]) -> str:
|
|
23
23
|
if self._takes_ctx:
|
|
24
|
-
args = (
|
|
24
|
+
args = (run_context,)
|
|
25
25
|
else:
|
|
26
26
|
args = ()
|
|
27
27
|
|
pydantic_ai/agent.py
CHANGED
|
@@ -5,12 +5,13 @@ import dataclasses
|
|
|
5
5
|
import inspect
|
|
6
6
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
7
7
|
from contextlib import asynccontextmanager, contextmanager
|
|
8
|
+
from contextvars import ContextVar
|
|
8
9
|
from dataclasses import dataclass, field
|
|
9
10
|
from types import FrameType
|
|
10
11
|
from typing import Any, Callable, Generic, Literal, cast, final, overload
|
|
11
12
|
|
|
12
13
|
import logfire_api
|
|
13
|
-
from typing_extensions import assert_never
|
|
14
|
+
from typing_extensions import assert_never, deprecated
|
|
14
15
|
|
|
15
16
|
from . import (
|
|
16
17
|
_result,
|
|
@@ -22,7 +23,7 @@ from . import (
|
|
|
22
23
|
result,
|
|
23
24
|
)
|
|
24
25
|
from .result import ResultData
|
|
25
|
-
from .settings import ModelSettings, merge_model_settings
|
|
26
|
+
from .settings import ModelSettings, UsageLimits, merge_model_settings
|
|
26
27
|
from .tools import (
|
|
27
28
|
AgentDeps,
|
|
28
29
|
RunContext,
|
|
@@ -35,7 +36,7 @@ from .tools import (
|
|
|
35
36
|
ToolPrepareFunc,
|
|
36
37
|
)
|
|
37
38
|
|
|
38
|
-
__all__ =
|
|
39
|
+
__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
|
|
39
40
|
|
|
40
41
|
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
|
|
41
42
|
|
|
@@ -89,12 +90,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
89
90
|
be merged with this value, with the runtime argument taking priority.
|
|
90
91
|
"""
|
|
91
92
|
|
|
92
|
-
last_run_messages: list[_messages.ModelMessage] | None
|
|
93
|
-
"""The messages from the last run, useful when a run raised an exception.
|
|
94
|
-
|
|
95
|
-
Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
|
|
96
|
-
"""
|
|
97
|
-
|
|
98
93
|
_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
|
|
99
94
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
|
|
100
95
|
_allow_text_result: bool = field(repr=False)
|
|
@@ -104,7 +99,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
104
99
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
|
|
105
100
|
_deps_type: type[AgentDeps] = field(repr=False)
|
|
106
101
|
_max_result_retries: int = field(repr=False)
|
|
107
|
-
_current_result_retry: int = field(repr=False)
|
|
108
102
|
_override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
|
|
109
103
|
_override_model: _utils.Option[models.Model] = field(default=None, repr=False)
|
|
110
104
|
|
|
@@ -162,7 +156,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
162
156
|
self.end_strategy = end_strategy
|
|
163
157
|
self.name = name
|
|
164
158
|
self.model_settings = model_settings
|
|
165
|
-
self.last_run_messages = None
|
|
166
159
|
self._result_schema = _result.ResultSchema[result_type].build(
|
|
167
160
|
result_type, result_tool_name, result_tool_description
|
|
168
161
|
)
|
|
@@ -180,7 +173,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
180
173
|
self._deps_type = deps_type
|
|
181
174
|
self._system_prompt_functions = []
|
|
182
175
|
self._max_result_retries = result_retries if result_retries is not None else retries
|
|
183
|
-
self._current_result_retry = 0
|
|
184
176
|
self._result_validators = []
|
|
185
177
|
|
|
186
178
|
async def run(
|
|
@@ -191,6 +183,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
191
183
|
model: models.Model | models.KnownModelName | None = None,
|
|
192
184
|
deps: AgentDeps = None,
|
|
193
185
|
model_settings: ModelSettings | None = None,
|
|
186
|
+
usage_limits: UsageLimits | None = None,
|
|
194
187
|
infer_name: bool = True,
|
|
195
188
|
) -> result.RunResult[ResultData]:
|
|
196
189
|
"""Run the agent with a user prompt in async mode.
|
|
@@ -211,8 +204,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
211
204
|
message_history: History of the conversation so far.
|
|
212
205
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
213
206
|
deps: Optional dependencies to use for this run.
|
|
214
|
-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
215
207
|
model_settings: Optional settings to use for this model's request.
|
|
208
|
+
usage_limits: Optional limits on model request count or token usage.
|
|
209
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
216
210
|
|
|
217
211
|
Returns:
|
|
218
212
|
The result of the run.
|
|
@@ -232,31 +226,37 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
232
226
|
model_name=model_used.name(),
|
|
233
227
|
agent_name=self.name or 'agent',
|
|
234
228
|
) as run_span:
|
|
235
|
-
|
|
229
|
+
run_context = RunContext(deps, 0, [], None, model_used)
|
|
230
|
+
messages = await self._prepare_messages(user_prompt, message_history, run_context)
|
|
231
|
+
run_context.messages = messages
|
|
236
232
|
|
|
237
233
|
for tool in self._function_tools.values():
|
|
238
234
|
tool.current_retry = 0
|
|
239
235
|
|
|
240
|
-
|
|
241
|
-
|
|
236
|
+
usage = result.Usage(requests=0)
|
|
242
237
|
model_settings = merge_model_settings(self.model_settings, model_settings)
|
|
238
|
+
usage_limits = usage_limits or UsageLimits()
|
|
243
239
|
|
|
244
240
|
run_step = 0
|
|
245
241
|
while True:
|
|
242
|
+
usage_limits.check_before_request(usage)
|
|
243
|
+
|
|
246
244
|
run_step += 1
|
|
247
245
|
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
|
|
248
|
-
agent_model = await self._prepare_model(
|
|
246
|
+
agent_model = await self._prepare_model(run_context)
|
|
249
247
|
|
|
250
248
|
with _logfire.span('model request', run_step=run_step) as model_req_span:
|
|
251
|
-
model_response,
|
|
249
|
+
model_response, request_usage = await agent_model.request(messages, model_settings)
|
|
252
250
|
model_req_span.set_attribute('response', model_response)
|
|
253
|
-
model_req_span.set_attribute('
|
|
251
|
+
model_req_span.set_attribute('usage', request_usage)
|
|
254
252
|
|
|
255
253
|
messages.append(model_response)
|
|
256
|
-
|
|
254
|
+
usage += request_usage
|
|
255
|
+
usage.requests += 1
|
|
256
|
+
usage_limits.check_tokens(request_usage)
|
|
257
257
|
|
|
258
258
|
with _logfire.span('handle model response', run_step=run_step) as handle_span:
|
|
259
|
-
final_result, tool_responses = await self._handle_model_response(model_response,
|
|
259
|
+
final_result, tool_responses = await self._handle_model_response(model_response, run_context)
|
|
260
260
|
|
|
261
261
|
if tool_responses:
|
|
262
262
|
# Add parts to the conversation as a new message
|
|
@@ -266,10 +266,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
266
266
|
if final_result is not None:
|
|
267
267
|
result_data = final_result.data
|
|
268
268
|
run_span.set_attribute('all_messages', messages)
|
|
269
|
-
run_span.set_attribute('
|
|
269
|
+
run_span.set_attribute('usage', usage)
|
|
270
270
|
handle_span.set_attribute('result', result_data)
|
|
271
271
|
handle_span.message = 'handle model response -> final result'
|
|
272
|
-
return result.RunResult(messages, new_message_index, result_data,
|
|
272
|
+
return result.RunResult(messages, new_message_index, result_data, usage)
|
|
273
273
|
else:
|
|
274
274
|
# continue the conversation
|
|
275
275
|
handle_span.set_attribute('tool_responses', tool_responses)
|
|
@@ -284,6 +284,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
284
284
|
model: models.Model | models.KnownModelName | None = None,
|
|
285
285
|
deps: AgentDeps = None,
|
|
286
286
|
model_settings: ModelSettings | None = None,
|
|
287
|
+
usage_limits: UsageLimits | None = None,
|
|
287
288
|
infer_name: bool = True,
|
|
288
289
|
) -> result.RunResult[ResultData]:
|
|
289
290
|
"""Run the agent with a user prompt synchronously.
|
|
@@ -308,8 +309,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
308
309
|
message_history: History of the conversation so far.
|
|
309
310
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
310
311
|
deps: Optional dependencies to use for this run.
|
|
311
|
-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
312
312
|
model_settings: Optional settings to use for this model's request.
|
|
313
|
+
usage_limits: Optional limits on model request count or token usage.
|
|
314
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
313
315
|
|
|
314
316
|
Returns:
|
|
315
317
|
The result of the run.
|
|
@@ -322,8 +324,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
322
324
|
message_history=message_history,
|
|
323
325
|
model=model,
|
|
324
326
|
deps=deps,
|
|
325
|
-
infer_name=False,
|
|
326
327
|
model_settings=model_settings,
|
|
328
|
+
usage_limits=usage_limits,
|
|
329
|
+
infer_name=False,
|
|
327
330
|
)
|
|
328
331
|
)
|
|
329
332
|
|
|
@@ -336,6 +339,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
336
339
|
model: models.Model | models.KnownModelName | None = None,
|
|
337
340
|
deps: AgentDeps = None,
|
|
338
341
|
model_settings: ModelSettings | None = None,
|
|
342
|
+
usage_limits: UsageLimits | None = None,
|
|
339
343
|
infer_name: bool = True,
|
|
340
344
|
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
|
|
341
345
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
@@ -357,8 +361,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
357
361
|
message_history: History of the conversation so far.
|
|
358
362
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
359
363
|
deps: Optional dependencies to use for this run.
|
|
360
|
-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
361
364
|
model_settings: Optional settings to use for this model's request.
|
|
365
|
+
usage_limits: Optional limits on model request count or token usage.
|
|
366
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
362
367
|
|
|
363
368
|
Returns:
|
|
364
369
|
The result of the run.
|
|
@@ -380,32 +385,35 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
380
385
|
model_name=model_used.name(),
|
|
381
386
|
agent_name=self.name or 'agent',
|
|
382
387
|
) as run_span:
|
|
383
|
-
|
|
388
|
+
run_context = RunContext(deps, 0, [], None, model_used)
|
|
389
|
+
messages = await self._prepare_messages(user_prompt, message_history, run_context)
|
|
390
|
+
run_context.messages = messages
|
|
384
391
|
|
|
385
392
|
for tool in self._function_tools.values():
|
|
386
393
|
tool.current_retry = 0
|
|
387
394
|
|
|
388
|
-
|
|
395
|
+
usage = result.Usage()
|
|
389
396
|
model_settings = merge_model_settings(self.model_settings, model_settings)
|
|
397
|
+
usage_limits = usage_limits or UsageLimits()
|
|
390
398
|
|
|
391
399
|
run_step = 0
|
|
392
400
|
while True:
|
|
393
401
|
run_step += 1
|
|
402
|
+
usage_limits.check_before_request(usage)
|
|
394
403
|
|
|
395
404
|
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
|
|
396
|
-
agent_model = await self._prepare_model(
|
|
405
|
+
agent_model = await self._prepare_model(run_context)
|
|
397
406
|
|
|
398
407
|
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
|
|
399
408
|
async with agent_model.request_stream(messages, model_settings) as model_response:
|
|
409
|
+
usage.requests += 1
|
|
400
410
|
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
|
|
401
411
|
# We want to end the "model request" span here, but we can't exit the context manager
|
|
402
412
|
# in the traditional way
|
|
403
413
|
model_req_span.__exit__(None, None, None)
|
|
404
414
|
|
|
405
415
|
with _logfire.span('handle model response') as handle_span:
|
|
406
|
-
maybe_final_result = await self._handle_streamed_model_response(
|
|
407
|
-
model_response, deps, messages
|
|
408
|
-
)
|
|
416
|
+
maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
|
|
409
417
|
|
|
410
418
|
# Check if we got a final result
|
|
411
419
|
if isinstance(maybe_final_result, _MarkFinalResult):
|
|
@@ -425,7 +433,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
425
433
|
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
|
|
426
434
|
]
|
|
427
435
|
parts = await self._process_function_tools(
|
|
428
|
-
tool_calls, result_tool_name,
|
|
436
|
+
tool_calls, result_tool_name, run_context
|
|
429
437
|
)
|
|
430
438
|
if parts:
|
|
431
439
|
messages.append(_messages.ModelRequest(parts))
|
|
@@ -434,10 +442,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
434
442
|
yield result.StreamedRunResult(
|
|
435
443
|
messages,
|
|
436
444
|
new_message_index,
|
|
437
|
-
|
|
445
|
+
usage,
|
|
446
|
+
usage_limits,
|
|
438
447
|
result_stream,
|
|
439
448
|
self._result_schema,
|
|
440
|
-
|
|
449
|
+
run_context,
|
|
441
450
|
self._result_validators,
|
|
442
451
|
result_tool_name,
|
|
443
452
|
on_complete,
|
|
@@ -455,8 +464,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
455
464
|
handle_span.set_attribute('tool_responses', tool_responses)
|
|
456
465
|
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
|
|
457
466
|
handle_span.message = f'handle model response -> {tool_responses_str}'
|
|
458
|
-
# the model_response should have been fully streamed by now, we can add
|
|
459
|
-
|
|
467
|
+
# the model_response should have been fully streamed by now, we can add its usage
|
|
468
|
+
model_response_usage = model_response.usage()
|
|
469
|
+
usage += model_response_usage
|
|
470
|
+
usage_limits.check_tokens(usage)
|
|
460
471
|
|
|
461
472
|
@contextmanager
|
|
462
473
|
def override(
|
|
@@ -597,7 +608,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
597
608
|
#> success (no tool calls)
|
|
598
609
|
```
|
|
599
610
|
"""
|
|
600
|
-
self._result_validators.append(_result.ResultValidator(func))
|
|
611
|
+
self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
|
|
601
612
|
return func
|
|
602
613
|
|
|
603
614
|
@overload
|
|
@@ -798,41 +809,50 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
798
809
|
|
|
799
810
|
return model_, mode_selection
|
|
800
811
|
|
|
801
|
-
async def _prepare_model(
|
|
802
|
-
|
|
803
|
-
) -> models.AgentModel:
|
|
804
|
-
"""Create building tools and create an agent model."""
|
|
812
|
+
async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
|
|
813
|
+
"""Build tools and create an agent model."""
|
|
805
814
|
function_tools: list[ToolDefinition] = []
|
|
806
815
|
|
|
807
816
|
async def add_tool(tool: Tool[AgentDeps]) -> None:
|
|
808
|
-
ctx =
|
|
817
|
+
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
|
|
809
818
|
if tool_def := await tool.prepare_tool_def(ctx):
|
|
810
819
|
function_tools.append(tool_def)
|
|
811
820
|
|
|
812
821
|
await asyncio.gather(*map(add_tool, self._function_tools.values()))
|
|
813
822
|
|
|
814
|
-
return await model.agent_model(
|
|
823
|
+
return await run_context.model.agent_model(
|
|
815
824
|
function_tools=function_tools,
|
|
816
825
|
allow_text_result=self._allow_text_result,
|
|
817
826
|
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
|
|
818
827
|
)
|
|
819
828
|
|
|
820
829
|
async def _prepare_messages(
|
|
821
|
-
self,
|
|
830
|
+
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
|
|
822
831
|
) -> list[_messages.ModelMessage]:
|
|
832
|
+
try:
|
|
833
|
+
messages = _messages_ctx_var.get()
|
|
834
|
+
except LookupError:
|
|
835
|
+
messages = []
|
|
836
|
+
else:
|
|
837
|
+
if messages:
|
|
838
|
+
raise exceptions.UserError(
|
|
839
|
+
'The capture_run_messages() context manager may only be used to wrap '
|
|
840
|
+
'one call to run(), run_sync(), or run_stream().'
|
|
841
|
+
)
|
|
842
|
+
|
|
823
843
|
if message_history:
|
|
824
844
|
# shallow copy messages
|
|
825
|
-
messages
|
|
845
|
+
messages.extend(message_history)
|
|
826
846
|
messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
|
|
827
847
|
else:
|
|
828
|
-
parts = await self._sys_parts(
|
|
848
|
+
parts = await self._sys_parts(run_context)
|
|
829
849
|
parts.append(_messages.UserPromptPart(user_prompt))
|
|
830
|
-
messages
|
|
850
|
+
messages.append(_messages.ModelRequest(parts))
|
|
831
851
|
|
|
832
852
|
return messages
|
|
833
853
|
|
|
834
854
|
async def _handle_model_response(
|
|
835
|
-
self, model_response: _messages.ModelResponse,
|
|
855
|
+
self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
|
|
836
856
|
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
|
|
837
857
|
"""Process a non-streamed response from the model.
|
|
838
858
|
|
|
@@ -841,42 +861,48 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
841
861
|
"""
|
|
842
862
|
texts: list[str] = []
|
|
843
863
|
tool_calls: list[_messages.ToolCallPart] = []
|
|
844
|
-
for
|
|
845
|
-
if isinstance(
|
|
846
|
-
|
|
864
|
+
for part in model_response.parts:
|
|
865
|
+
if isinstance(part, _messages.TextPart):
|
|
866
|
+
# ignore empty content for text parts, see #437
|
|
867
|
+
if part.content:
|
|
868
|
+
texts.append(part.content)
|
|
847
869
|
else:
|
|
848
|
-
tool_calls.append(
|
|
849
|
-
|
|
850
|
-
if
|
|
870
|
+
tool_calls.append(part)
|
|
871
|
+
|
|
872
|
+
# At the moment, we prioritize at least executing tool calls if they are present.
|
|
873
|
+
# In the future, we'd consider making this configurable at the agent or run level.
|
|
874
|
+
# This accounts for cases like anthropic returns that might contain a text response
|
|
875
|
+
# and a tool call response, where the text response just indicates the tool call will happen.
|
|
876
|
+
if tool_calls:
|
|
877
|
+
return await self._handle_structured_response(tool_calls, run_context)
|
|
878
|
+
elif texts:
|
|
851
879
|
text = '\n\n'.join(texts)
|
|
852
|
-
return await self._handle_text_response(text,
|
|
853
|
-
elif tool_calls:
|
|
854
|
-
return await self._handle_structured_response(tool_calls, deps, conv_messages)
|
|
880
|
+
return await self._handle_text_response(text, run_context)
|
|
855
881
|
else:
|
|
856
882
|
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
857
883
|
|
|
858
884
|
async def _handle_text_response(
|
|
859
|
-
self, text: str,
|
|
885
|
+
self, text: str, run_context: RunContext[AgentDeps]
|
|
860
886
|
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
|
|
861
887
|
"""Handle a plain text response from the model for non-streaming responses."""
|
|
862
888
|
if self._allow_text_result:
|
|
863
889
|
result_data_input = cast(ResultData, text)
|
|
864
890
|
try:
|
|
865
|
-
result_data = await self._validate_result(result_data_input,
|
|
891
|
+
result_data = await self._validate_result(result_data_input, run_context, None)
|
|
866
892
|
except _result.ToolRetryError as e:
|
|
867
|
-
self._incr_result_retry()
|
|
893
|
+
self._incr_result_retry(run_context)
|
|
868
894
|
return None, [e.tool_retry]
|
|
869
895
|
else:
|
|
870
896
|
return _MarkFinalResult(result_data, None), []
|
|
871
897
|
else:
|
|
872
|
-
self._incr_result_retry()
|
|
898
|
+
self._incr_result_retry(run_context)
|
|
873
899
|
response = _messages.RetryPromptPart(
|
|
874
900
|
content='Plain text responses are not permitted, please call one of the functions instead.',
|
|
875
901
|
)
|
|
876
902
|
return None, [response]
|
|
877
903
|
|
|
878
904
|
async def _handle_structured_response(
|
|
879
|
-
self, tool_calls: list[_messages.ToolCallPart],
|
|
905
|
+
self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
|
|
880
906
|
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
|
|
881
907
|
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
|
|
882
908
|
assert tool_calls, 'Expected at least one tool call'
|
|
@@ -890,17 +916,15 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
890
916
|
call, result_tool = match
|
|
891
917
|
try:
|
|
892
918
|
result_data = result_tool.validate(call)
|
|
893
|
-
result_data = await self._validate_result(result_data,
|
|
919
|
+
result_data = await self._validate_result(result_data, run_context, call)
|
|
894
920
|
except _result.ToolRetryError as e:
|
|
895
|
-
self._incr_result_retry()
|
|
921
|
+
self._incr_result_retry(run_context)
|
|
896
922
|
parts.append(e.tool_retry)
|
|
897
923
|
else:
|
|
898
924
|
final_result = _MarkFinalResult(result_data, call.tool_name)
|
|
899
925
|
|
|
900
926
|
# Then build the other request parts based on end strategy
|
|
901
|
-
parts += await self._process_function_tools(
|
|
902
|
-
tool_calls, final_result and final_result.tool_name, deps, conv_messages
|
|
903
|
-
)
|
|
927
|
+
parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
|
|
904
928
|
|
|
905
929
|
return final_result, parts
|
|
906
930
|
|
|
@@ -908,8 +932,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
908
932
|
self,
|
|
909
933
|
tool_calls: list[_messages.ToolCallPart],
|
|
910
934
|
result_tool_name: str | None,
|
|
911
|
-
|
|
912
|
-
conv_messages: list[_messages.ModelMessage],
|
|
935
|
+
run_context: RunContext[AgentDeps],
|
|
913
936
|
) -> list[_messages.ModelRequestPart]:
|
|
914
937
|
"""Process function (non-result) tool calls in parallel.
|
|
915
938
|
|
|
@@ -942,7 +965,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
942
965
|
)
|
|
943
966
|
)
|
|
944
967
|
else:
|
|
945
|
-
tasks.append(asyncio.create_task(tool.run(
|
|
968
|
+
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
|
|
946
969
|
elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
|
|
947
970
|
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
|
|
948
971
|
# validation, we don't add another part here
|
|
@@ -955,7 +978,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
955
978
|
)
|
|
956
979
|
)
|
|
957
980
|
else:
|
|
958
|
-
parts.append(self._unknown_tool(call.tool_name))
|
|
981
|
+
parts.append(self._unknown_tool(call.tool_name, run_context))
|
|
959
982
|
|
|
960
983
|
# Run all tool tasks in parallel
|
|
961
984
|
if tasks:
|
|
@@ -967,8 +990,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
967
990
|
async def _handle_streamed_model_response(
|
|
968
991
|
self,
|
|
969
992
|
model_response: models.EitherStreamedResponse,
|
|
970
|
-
|
|
971
|
-
conv_messages: list[_messages.ModelMessage],
|
|
993
|
+
run_context: RunContext[AgentDeps],
|
|
972
994
|
) -> (
|
|
973
995
|
_MarkFinalResult[models.EitherStreamedResponse]
|
|
974
996
|
| tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
|
|
@@ -984,11 +1006,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
984
1006
|
if self._allow_text_result:
|
|
985
1007
|
return _MarkFinalResult(model_response, None)
|
|
986
1008
|
else:
|
|
987
|
-
self._incr_result_retry()
|
|
1009
|
+
self._incr_result_retry(run_context)
|
|
988
1010
|
response = _messages.RetryPromptPart(
|
|
989
1011
|
content='Plain text responses are not permitted, please call one of the functions instead.',
|
|
990
1012
|
)
|
|
991
|
-
# stream the response, so
|
|
1013
|
+
# stream the response, so usage is correct
|
|
992
1014
|
async for _ in model_response:
|
|
993
1015
|
pass
|
|
994
1016
|
|
|
@@ -1024,9 +1046,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1024
1046
|
if isinstance(item, _messages.ToolCallPart):
|
|
1025
1047
|
call = item
|
|
1026
1048
|
if tool := self._function_tools.get(call.tool_name):
|
|
1027
|
-
tasks.append(asyncio.create_task(tool.run(
|
|
1049
|
+
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
|
|
1028
1050
|
else:
|
|
1029
|
-
parts.append(self._unknown_tool(call.tool_name))
|
|
1051
|
+
parts.append(self._unknown_tool(call.tool_name, run_context))
|
|
1030
1052
|
|
|
1031
1053
|
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
|
|
1032
1054
|
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
|
|
@@ -1038,33 +1060,30 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1038
1060
|
async def _validate_result(
|
|
1039
1061
|
self,
|
|
1040
1062
|
result_data: ResultData,
|
|
1041
|
-
|
|
1063
|
+
run_context: RunContext[AgentDeps],
|
|
1042
1064
|
tool_call: _messages.ToolCallPart | None,
|
|
1043
|
-
conv_messages: list[_messages.ModelMessage],
|
|
1044
1065
|
) -> ResultData:
|
|
1045
1066
|
for validator in self._result_validators:
|
|
1046
|
-
result_data = await validator.validate(
|
|
1047
|
-
result_data, deps, self._current_result_retry, tool_call, conv_messages
|
|
1048
|
-
)
|
|
1067
|
+
result_data = await validator.validate(result_data, tool_call, run_context)
|
|
1049
1068
|
return result_data
|
|
1050
1069
|
|
|
1051
|
-
def _incr_result_retry(self) -> None:
|
|
1052
|
-
|
|
1053
|
-
if
|
|
1070
|
+
def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
|
|
1071
|
+
run_context.retry += 1
|
|
1072
|
+
if run_context.retry > self._max_result_retries:
|
|
1054
1073
|
raise exceptions.UnexpectedModelBehavior(
|
|
1055
1074
|
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
|
|
1056
1075
|
)
|
|
1057
1076
|
|
|
1058
|
-
async def _sys_parts(self,
|
|
1077
|
+
async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
|
|
1059
1078
|
"""Build the initial messages for the conversation."""
|
|
1060
1079
|
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
|
|
1061
1080
|
for sys_prompt_runner in self._system_prompt_functions:
|
|
1062
|
-
prompt = await sys_prompt_runner.run(
|
|
1081
|
+
prompt = await sys_prompt_runner.run(run_context)
|
|
1063
1082
|
messages.append(_messages.SystemPromptPart(prompt))
|
|
1064
1083
|
return messages
|
|
1065
1084
|
|
|
1066
|
-
def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
|
|
1067
|
-
self._incr_result_retry()
|
|
1085
|
+
def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
|
|
1086
|
+
self._incr_result_retry(run_context)
|
|
1068
1087
|
names = list(self._function_tools.keys())
|
|
1069
1088
|
if self._result_schema:
|
|
1070
1089
|
names.extend(self._result_schema.tool_names())
|
|
@@ -1105,6 +1124,51 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1105
1124
|
self.name = name
|
|
1106
1125
|
return
|
|
1107
1126
|
|
|
1127
|
+
@property
|
|
1128
|
+
@deprecated(
|
|
1129
|
+
'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
|
|
1130
|
+
)
|
|
1131
|
+
def last_run_messages(self) -> list[_messages.ModelMessage]:
|
|
1132
|
+
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
@contextmanager
|
|
1139
|
+
def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
|
|
1140
|
+
"""Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
|
|
1141
|
+
|
|
1142
|
+
Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
|
|
1143
|
+
|
|
1144
|
+
Examples:
|
|
1145
|
+
```python
|
|
1146
|
+
from pydantic_ai import Agent, capture_run_messages
|
|
1147
|
+
|
|
1148
|
+
agent = Agent('test')
|
|
1149
|
+
|
|
1150
|
+
with capture_run_messages() as messages:
|
|
1151
|
+
try:
|
|
1152
|
+
result = agent.run_sync('foobar')
|
|
1153
|
+
except Exception:
|
|
1154
|
+
print(messages)
|
|
1155
|
+
raise
|
|
1156
|
+
```
|
|
1157
|
+
|
|
1158
|
+
!!! note
|
|
1159
|
+
You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
|
|
1160
|
+
If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
|
|
1161
|
+
"""
|
|
1162
|
+
try:
|
|
1163
|
+
yield _messages_ctx_var.get()
|
|
1164
|
+
except LookupError:
|
|
1165
|
+
messages: list[_messages.ModelMessage] = []
|
|
1166
|
+
token = _messages_ctx_var.set(messages)
|
|
1167
|
+
try:
|
|
1168
|
+
yield messages
|
|
1169
|
+
finally:
|
|
1170
|
+
_messages_ctx_var.reset(token)
|
|
1171
|
+
|
|
1108
1172
|
|
|
1109
1173
|
@dataclass
|
|
1110
1174
|
class _MarkFinalResult(Generic[ResultData]):
|