pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +23 -4
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_system_prompt.py +1 -0
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +332 -124
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +224 -9
- pydantic_ai/models/__init__.py +59 -82
- pydantic_ai/models/anthropic.py +22 -22
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +86 -125
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/models/vertexai.py +1 -1
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.17.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -4,13 +4,13 @@ import asyncio
|
|
|
4
4
|
import dataclasses
|
|
5
5
|
import inspect
|
|
6
6
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
7
|
-
from contextlib import asynccontextmanager, contextmanager
|
|
7
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
|
|
8
8
|
from contextvars import ContextVar
|
|
9
9
|
from types import FrameType
|
|
10
10
|
from typing import Any, Callable, Generic, Literal, cast, final, overload
|
|
11
11
|
|
|
12
12
|
import logfire_api
|
|
13
|
-
from typing_extensions import assert_never, deprecated
|
|
13
|
+
from typing_extensions import TypeVar, assert_never, deprecated
|
|
14
14
|
|
|
15
15
|
from . import (
|
|
16
16
|
_result,
|
|
@@ -26,6 +26,7 @@ from .result import ResultData
|
|
|
26
26
|
from .settings import ModelSettings, merge_model_settings
|
|
27
27
|
from .tools import (
|
|
28
28
|
AgentDeps,
|
|
29
|
+
DocstringFormat,
|
|
29
30
|
RunContext,
|
|
30
31
|
Tool,
|
|
31
32
|
ToolDefinition,
|
|
@@ -57,6 +58,8 @@ EndStrategy = Literal['early', 'exhaustive']
|
|
|
57
58
|
- `'early'`: Stop processing other tool calls once a final result is found
|
|
58
59
|
- `'exhaustive'`: Process all tool calls even after finding a final result
|
|
59
60
|
"""
|
|
61
|
+
RunResultData = TypeVar('RunResultData')
|
|
62
|
+
"""Type variable for the result data of a run where `result_type` was customized on the run call."""
|
|
60
63
|
|
|
61
64
|
|
|
62
65
|
@final
|
|
@@ -99,14 +102,17 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
99
102
|
Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
|
|
100
103
|
be merged with this value, with the runtime argument taking priority.
|
|
101
104
|
"""
|
|
102
|
-
|
|
105
|
+
_result_tool_name: str = dataclasses.field(repr=False)
|
|
106
|
+
_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
103
107
|
_result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
|
|
104
108
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
|
|
105
|
-
_allow_text_result: bool = dataclasses.field(repr=False)
|
|
106
109
|
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
|
|
107
110
|
_function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
|
|
108
111
|
_default_retries: int = dataclasses.field(repr=False)
|
|
109
112
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
|
|
113
|
+
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
|
|
114
|
+
repr=False
|
|
115
|
+
)
|
|
110
116
|
_deps_type: type[AgentDeps] = dataclasses.field(repr=False)
|
|
111
117
|
_max_result_retries: int = dataclasses.field(repr=False)
|
|
112
118
|
_override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
|
|
@@ -166,11 +172,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
166
172
|
self.end_strategy = end_strategy
|
|
167
173
|
self.name = name
|
|
168
174
|
self.model_settings = model_settings
|
|
175
|
+
self._result_tool_name = result_tool_name
|
|
176
|
+
self._result_tool_description = result_tool_description
|
|
169
177
|
self._result_schema = _result.ResultSchema[result_type].build(
|
|
170
178
|
result_type, result_tool_name, result_tool_description
|
|
171
179
|
)
|
|
172
|
-
# if the result tool is None, or its schema allows `str`, we allow plain text results
|
|
173
|
-
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
|
|
174
180
|
|
|
175
181
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
176
182
|
self._function_tools = {}
|
|
@@ -182,13 +188,31 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
182
188
|
self._register_tool(Tool(tool))
|
|
183
189
|
self._deps_type = deps_type
|
|
184
190
|
self._system_prompt_functions = []
|
|
191
|
+
self._system_prompt_dynamic_functions = {}
|
|
185
192
|
self._max_result_retries = result_retries if result_retries is not None else retries
|
|
186
193
|
self._result_validators = []
|
|
187
194
|
|
|
195
|
+
@overload
|
|
196
|
+
async def run(
|
|
197
|
+
self,
|
|
198
|
+
user_prompt: str,
|
|
199
|
+
*,
|
|
200
|
+
result_type: None = None,
|
|
201
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
202
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
203
|
+
deps: AgentDeps = None,
|
|
204
|
+
model_settings: ModelSettings | None = None,
|
|
205
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
206
|
+
usage: _usage.Usage | None = None,
|
|
207
|
+
infer_name: bool = True,
|
|
208
|
+
) -> result.RunResult[ResultData]: ...
|
|
209
|
+
|
|
210
|
+
@overload
|
|
188
211
|
async def run(
|
|
189
212
|
self,
|
|
190
213
|
user_prompt: str,
|
|
191
214
|
*,
|
|
215
|
+
result_type: type[RunResultData],
|
|
192
216
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
193
217
|
model: models.Model | models.KnownModelName | None = None,
|
|
194
218
|
deps: AgentDeps = None,
|
|
@@ -196,7 +220,21 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
196
220
|
usage_limits: _usage.UsageLimits | None = None,
|
|
197
221
|
usage: _usage.Usage | None = None,
|
|
198
222
|
infer_name: bool = True,
|
|
199
|
-
) -> result.RunResult[
|
|
223
|
+
) -> result.RunResult[RunResultData]: ...
|
|
224
|
+
|
|
225
|
+
async def run(
|
|
226
|
+
self,
|
|
227
|
+
user_prompt: str,
|
|
228
|
+
*,
|
|
229
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
230
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
231
|
+
deps: AgentDeps = None,
|
|
232
|
+
model_settings: ModelSettings | None = None,
|
|
233
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
234
|
+
usage: _usage.Usage | None = None,
|
|
235
|
+
result_type: type[RunResultData] | None = None,
|
|
236
|
+
infer_name: bool = True,
|
|
237
|
+
) -> result.RunResult[Any]:
|
|
200
238
|
"""Run the agent with a user prompt in async mode.
|
|
201
239
|
|
|
202
240
|
Example:
|
|
@@ -205,12 +243,15 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
205
243
|
|
|
206
244
|
agent = Agent('openai:gpt-4o')
|
|
207
245
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
246
|
+
async def main():
|
|
247
|
+
result = await agent.run('What is the capital of France?')
|
|
248
|
+
print(result.data)
|
|
249
|
+
#> Paris
|
|
211
250
|
```
|
|
212
251
|
|
|
213
252
|
Args:
|
|
253
|
+
result_type: Custom result type to use for this run; `result_type` may only be used if the agent has no
|
|
254
|
+
result validators since result validators would expect an argument that matches the agent's result type.
|
|
214
255
|
user_prompt: User input to start/continue the conversation.
|
|
215
256
|
message_history: History of the conversation so far.
|
|
216
257
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
@@ -229,6 +270,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
229
270
|
|
|
230
271
|
deps = self._get_deps(deps)
|
|
231
272
|
new_message_index = len(message_history) if message_history else 0
|
|
273
|
+
result_schema = self._prepare_result_schema(result_type)
|
|
232
274
|
|
|
233
275
|
with _logfire.span(
|
|
234
276
|
'{agent_name} run {prompt=}',
|
|
@@ -252,7 +294,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
252
294
|
|
|
253
295
|
run_context.run_step += 1
|
|
254
296
|
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
|
|
255
|
-
agent_model = await self._prepare_model(run_context)
|
|
297
|
+
agent_model = await self._prepare_model(run_context, result_schema)
|
|
256
298
|
|
|
257
299
|
with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
|
|
258
300
|
model_response, request_usage = await agent_model.request(messages, model_settings)
|
|
@@ -264,7 +306,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
264
306
|
usage_limits.check_tokens(run_context.usage)
|
|
265
307
|
|
|
266
308
|
with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
|
|
267
|
-
final_result, tool_responses = await self._handle_model_response(
|
|
309
|
+
final_result, tool_responses = await self._handle_model_response(
|
|
310
|
+
model_response, run_context, result_schema
|
|
311
|
+
)
|
|
268
312
|
|
|
269
313
|
if tool_responses:
|
|
270
314
|
# Add parts to the conversation as a new message
|
|
@@ -287,10 +331,26 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
287
331
|
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
|
|
288
332
|
handle_span.message = f'handle model response -> {tool_responses_str}'
|
|
289
333
|
|
|
334
|
+
@overload
|
|
335
|
+
def run_sync(
|
|
336
|
+
self,
|
|
337
|
+
user_prompt: str,
|
|
338
|
+
*,
|
|
339
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
340
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
341
|
+
deps: AgentDeps = None,
|
|
342
|
+
model_settings: ModelSettings | None = None,
|
|
343
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
344
|
+
usage: _usage.Usage | None = None,
|
|
345
|
+
infer_name: bool = True,
|
|
346
|
+
) -> result.RunResult[ResultData]: ...
|
|
347
|
+
|
|
348
|
+
@overload
|
|
290
349
|
def run_sync(
|
|
291
350
|
self,
|
|
292
351
|
user_prompt: str,
|
|
293
352
|
*,
|
|
353
|
+
result_type: type[RunResultData] | None,
|
|
294
354
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
295
355
|
model: models.Model | models.KnownModelName | None = None,
|
|
296
356
|
deps: AgentDeps = None,
|
|
@@ -298,7 +358,21 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
298
358
|
usage_limits: _usage.UsageLimits | None = None,
|
|
299
359
|
usage: _usage.Usage | None = None,
|
|
300
360
|
infer_name: bool = True,
|
|
301
|
-
) -> result.RunResult[
|
|
361
|
+
) -> result.RunResult[RunResultData]: ...
|
|
362
|
+
|
|
363
|
+
def run_sync(
|
|
364
|
+
self,
|
|
365
|
+
user_prompt: str,
|
|
366
|
+
*,
|
|
367
|
+
result_type: type[RunResultData] | None = None,
|
|
368
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
369
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
370
|
+
deps: AgentDeps = None,
|
|
371
|
+
model_settings: ModelSettings | None = None,
|
|
372
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
373
|
+
usage: _usage.Usage | None = None,
|
|
374
|
+
infer_name: bool = True,
|
|
375
|
+
) -> result.RunResult[Any]:
|
|
302
376
|
"""Run the agent with a user prompt synchronously.
|
|
303
377
|
|
|
304
378
|
This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
|
|
@@ -310,13 +384,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
310
384
|
|
|
311
385
|
agent = Agent('openai:gpt-4o')
|
|
312
386
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
#> Paris
|
|
387
|
+
result_sync = agent.run_sync('What is the capital of Italy?')
|
|
388
|
+
print(result_sync.data)
|
|
389
|
+
#> Rome
|
|
317
390
|
```
|
|
318
391
|
|
|
319
392
|
Args:
|
|
393
|
+
result_type: Custom result type to use for this run; `result_type` may only be used if the agent has no
|
|
394
|
+
result validators since result validators would expect an argument that matches the agent's result type.
|
|
320
395
|
user_prompt: User input to start/continue the conversation.
|
|
321
396
|
message_history: History of the conversation so far.
|
|
322
397
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
@@ -334,6 +409,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
334
409
|
return asyncio.get_event_loop().run_until_complete(
|
|
335
410
|
self.run(
|
|
336
411
|
user_prompt,
|
|
412
|
+
result_type=result_type,
|
|
337
413
|
message_history=message_history,
|
|
338
414
|
model=model,
|
|
339
415
|
deps=deps,
|
|
@@ -344,11 +420,42 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
344
420
|
)
|
|
345
421
|
)
|
|
346
422
|
|
|
423
|
+
@overload
|
|
424
|
+
def run_stream(
|
|
425
|
+
self,
|
|
426
|
+
user_prompt: str,
|
|
427
|
+
*,
|
|
428
|
+
result_type: None = None,
|
|
429
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
430
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
431
|
+
deps: AgentDeps = None,
|
|
432
|
+
model_settings: ModelSettings | None = None,
|
|
433
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
434
|
+
usage: _usage.Usage | None = None,
|
|
435
|
+
infer_name: bool = True,
|
|
436
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
|
|
437
|
+
|
|
438
|
+
@overload
|
|
439
|
+
def run_stream(
|
|
440
|
+
self,
|
|
441
|
+
user_prompt: str,
|
|
442
|
+
*,
|
|
443
|
+
result_type: type[RunResultData],
|
|
444
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
445
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
446
|
+
deps: AgentDeps = None,
|
|
447
|
+
model_settings: ModelSettings | None = None,
|
|
448
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
449
|
+
usage: _usage.Usage | None = None,
|
|
450
|
+
infer_name: bool = True,
|
|
451
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
|
|
452
|
+
|
|
347
453
|
@asynccontextmanager
|
|
348
454
|
async def run_stream(
|
|
349
455
|
self,
|
|
350
456
|
user_prompt: str,
|
|
351
457
|
*,
|
|
458
|
+
result_type: type[RunResultData] | None = None,
|
|
352
459
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
353
460
|
model: models.Model | models.KnownModelName | None = None,
|
|
354
461
|
deps: AgentDeps = None,
|
|
@@ -356,7 +463,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
356
463
|
usage_limits: _usage.UsageLimits | None = None,
|
|
357
464
|
usage: _usage.Usage | None = None,
|
|
358
465
|
infer_name: bool = True,
|
|
359
|
-
) -> AsyncIterator[result.StreamedRunResult[AgentDeps,
|
|
466
|
+
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
|
|
360
467
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
361
468
|
|
|
362
469
|
Example:
|
|
@@ -372,6 +479,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
372
479
|
```
|
|
373
480
|
|
|
374
481
|
Args:
|
|
482
|
+
result_type: Custom result type to use for this run; `result_type` may only be used if the agent has no
|
|
483
|
+
result validators since result validators would expect an argument that matches the agent's result type.
|
|
375
484
|
user_prompt: User input to start/continue the conversation.
|
|
376
485
|
message_history: History of the conversation so far.
|
|
377
486
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
@@ -392,6 +501,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
392
501
|
|
|
393
502
|
deps = self._get_deps(deps)
|
|
394
503
|
new_message_index = len(message_history) if message_history else 0
|
|
504
|
+
result_schema = self._prepare_result_schema(result_type)
|
|
395
505
|
|
|
396
506
|
with _logfire.span(
|
|
397
507
|
'{agent_name} run stream {prompt=}',
|
|
@@ -415,7 +525,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
415
525
|
usage_limits.check_before_request(run_context.usage)
|
|
416
526
|
|
|
417
527
|
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
|
|
418
|
-
agent_model = await self._prepare_model(run_context)
|
|
528
|
+
agent_model = await self._prepare_model(run_context, result_schema)
|
|
419
529
|
|
|
420
530
|
with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
|
|
421
531
|
async with agent_model.request_stream(messages, model_settings) as model_response:
|
|
@@ -426,7 +536,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
426
536
|
model_req_span.__exit__(None, None, None)
|
|
427
537
|
|
|
428
538
|
with _logfire.span('handle model response') as handle_span:
|
|
429
|
-
maybe_final_result = await self.
|
|
539
|
+
maybe_final_result = await self._handle_streamed_response(
|
|
540
|
+
model_response, run_context, result_schema
|
|
541
|
+
)
|
|
430
542
|
|
|
431
543
|
# Check if we got a final result
|
|
432
544
|
if isinstance(maybe_final_result, _MarkFinalResult):
|
|
@@ -446,7 +558,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
446
558
|
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
|
|
447
559
|
]
|
|
448
560
|
parts = await self._process_function_tools(
|
|
449
|
-
tool_calls, result_tool_name, run_context
|
|
561
|
+
tool_calls, result_tool_name, run_context, result_schema
|
|
450
562
|
)
|
|
451
563
|
if parts:
|
|
452
564
|
messages.append(_messages.ModelRequest(parts))
|
|
@@ -457,7 +569,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
457
569
|
new_message_index,
|
|
458
570
|
usage_limits,
|
|
459
571
|
result_stream,
|
|
460
|
-
|
|
572
|
+
result_schema,
|
|
461
573
|
run_context,
|
|
462
574
|
self._result_validators,
|
|
463
575
|
result_tool_name,
|
|
@@ -535,17 +647,37 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
535
647
|
@overload
|
|
536
648
|
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
|
|
537
649
|
|
|
650
|
+
@overload
|
|
651
|
+
def system_prompt(
|
|
652
|
+
self, /, *, dynamic: bool = False
|
|
653
|
+
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
|
|
654
|
+
|
|
538
655
|
def system_prompt(
|
|
539
|
-
self,
|
|
540
|
-
|
|
656
|
+
self,
|
|
657
|
+
func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
|
|
658
|
+
/,
|
|
659
|
+
*,
|
|
660
|
+
dynamic: bool = False,
|
|
661
|
+
) -> (
|
|
662
|
+
Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
|
|
663
|
+
| _system_prompt.SystemPromptFunc[AgentDeps]
|
|
664
|
+
):
|
|
541
665
|
"""Decorator to register a system prompt function.
|
|
542
666
|
|
|
543
667
|
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
|
|
544
668
|
Can decorate sync or async functions.
|
|
545
669
|
|
|
670
|
+
The decorator can be used either bare (`agent.system_prompt`) or as a function call
|
|
671
|
+
(`agent.system_prompt(...)`), see the examples below.
|
|
672
|
+
|
|
546
673
|
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
|
|
547
674
|
the type of the function, see `tests/typed_agent.py` for tests.
|
|
548
675
|
|
|
676
|
+
Args:
|
|
677
|
+
func: The function to decorate
|
|
678
|
+
dynamic: If True, the system prompt will be reevaluated even when `message_history` is provided,
|
|
679
|
+
see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
|
|
680
|
+
|
|
549
681
|
Example:
|
|
550
682
|
```python
|
|
551
683
|
from pydantic_ai import Agent, RunContext
|
|
@@ -556,17 +688,27 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
556
688
|
def simple_system_prompt() -> str:
|
|
557
689
|
return 'foobar'
|
|
558
690
|
|
|
559
|
-
@agent.system_prompt
|
|
691
|
+
@agent.system_prompt(dynamic=True)
|
|
560
692
|
async def async_system_prompt(ctx: RunContext[str]) -> str:
|
|
561
693
|
return f'{ctx.deps} is the best'
|
|
562
|
-
|
|
563
|
-
result = agent.run_sync('foobar', deps='spam')
|
|
564
|
-
print(result.data)
|
|
565
|
-
#> success (no tool calls)
|
|
566
694
|
```
|
|
567
695
|
"""
|
|
568
|
-
|
|
569
|
-
|
|
696
|
+
if func is None:
|
|
697
|
+
|
|
698
|
+
def decorator(
|
|
699
|
+
func_: _system_prompt.SystemPromptFunc[AgentDeps],
|
|
700
|
+
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
|
|
701
|
+
runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
|
|
702
|
+
self._system_prompt_functions.append(runner)
|
|
703
|
+
if dynamic:
|
|
704
|
+
self._system_prompt_dynamic_functions[func_.__qualname__] = runner
|
|
705
|
+
return func_
|
|
706
|
+
|
|
707
|
+
return decorator
|
|
708
|
+
else:
|
|
709
|
+
assert not dynamic, "dynamic can't be True in this case"
|
|
710
|
+
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
|
|
711
|
+
return func
|
|
570
712
|
|
|
571
713
|
@overload
|
|
572
714
|
def result_validator(
|
|
@@ -633,6 +775,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
633
775
|
*,
|
|
634
776
|
retries: int | None = None,
|
|
635
777
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
778
|
+
docstring_format: DocstringFormat = 'auto',
|
|
779
|
+
require_parameter_descriptions: bool = False,
|
|
636
780
|
) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
|
|
637
781
|
|
|
638
782
|
def tool(
|
|
@@ -642,6 +786,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
642
786
|
*,
|
|
643
787
|
retries: int | None = None,
|
|
644
788
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
789
|
+
docstring_format: DocstringFormat = 'auto',
|
|
790
|
+
require_parameter_descriptions: bool = False,
|
|
645
791
|
) -> Any:
|
|
646
792
|
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
647
793
|
|
|
@@ -679,6 +825,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
679
825
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
680
826
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
681
827
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
828
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
829
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
830
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
682
831
|
"""
|
|
683
832
|
if func is None:
|
|
684
833
|
|
|
@@ -686,13 +835,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
686
835
|
func_: ToolFuncContext[AgentDeps, ToolParams],
|
|
687
836
|
) -> ToolFuncContext[AgentDeps, ToolParams]:
|
|
688
837
|
# noinspection PyTypeChecker
|
|
689
|
-
self._register_function(func_, True, retries, prepare)
|
|
838
|
+
self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
690
839
|
return func_
|
|
691
840
|
|
|
692
841
|
return tool_decorator
|
|
693
842
|
else:
|
|
694
843
|
# noinspection PyTypeChecker
|
|
695
|
-
self._register_function(func, True, retries, prepare)
|
|
844
|
+
self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
696
845
|
return func
|
|
697
846
|
|
|
698
847
|
@overload
|
|
@@ -705,6 +854,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
705
854
|
*,
|
|
706
855
|
retries: int | None = None,
|
|
707
856
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
857
|
+
docstring_format: DocstringFormat = 'auto',
|
|
858
|
+
require_parameter_descriptions: bool = False,
|
|
708
859
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
709
860
|
|
|
710
861
|
def tool_plain(
|
|
@@ -714,6 +865,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
714
865
|
*,
|
|
715
866
|
retries: int | None = None,
|
|
716
867
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
868
|
+
docstring_format: DocstringFormat = 'auto',
|
|
869
|
+
require_parameter_descriptions: bool = False,
|
|
717
870
|
) -> Any:
|
|
718
871
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
719
872
|
|
|
@@ -751,17 +904,22 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
751
904
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
752
905
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
753
906
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
907
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
908
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
909
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
754
910
|
"""
|
|
755
911
|
if func is None:
|
|
756
912
|
|
|
757
913
|
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
758
914
|
# noinspection PyTypeChecker
|
|
759
|
-
self._register_function(
|
|
915
|
+
self._register_function(
|
|
916
|
+
func_, False, retries, prepare, docstring_format, require_parameter_descriptions
|
|
917
|
+
)
|
|
760
918
|
return func_
|
|
761
919
|
|
|
762
920
|
return tool_decorator
|
|
763
921
|
else:
|
|
764
|
-
self._register_function(func, False, retries, prepare)
|
|
922
|
+
self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
765
923
|
return func
|
|
766
924
|
|
|
767
925
|
def _register_function(
|
|
@@ -770,10 +928,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
770
928
|
takes_ctx: bool,
|
|
771
929
|
retries: int | None,
|
|
772
930
|
prepare: ToolPrepareFunc[AgentDeps] | None,
|
|
931
|
+
docstring_format: DocstringFormat,
|
|
932
|
+
require_parameter_descriptions: bool,
|
|
773
933
|
) -> None:
|
|
774
934
|
"""Private utility to register a function as a tool."""
|
|
775
935
|
retries_ = retries if retries is not None else self._default_retries
|
|
776
|
-
tool = Tool(
|
|
936
|
+
tool = Tool(
|
|
937
|
+
func,
|
|
938
|
+
takes_ctx=takes_ctx,
|
|
939
|
+
max_retries=retries_,
|
|
940
|
+
prepare=prepare,
|
|
941
|
+
docstring_format=docstring_format,
|
|
942
|
+
require_parameter_descriptions=require_parameter_descriptions,
|
|
943
|
+
)
|
|
777
944
|
self._register_tool(tool)
|
|
778
945
|
|
|
779
946
|
def _register_tool(self, tool: Tool[AgentDeps]) -> None:
|
|
@@ -818,7 +985,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
818
985
|
|
|
819
986
|
return model_
|
|
820
987
|
|
|
821
|
-
async def _prepare_model(
|
|
988
|
+
async def _prepare_model(
|
|
989
|
+
self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
|
|
990
|
+
) -> models.AgentModel:
|
|
822
991
|
"""Build tools and create an agent model."""
|
|
823
992
|
function_tools: list[ToolDefinition] = []
|
|
824
993
|
|
|
@@ -831,10 +1000,39 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
831
1000
|
|
|
832
1001
|
return await run_context.model.agent_model(
|
|
833
1002
|
function_tools=function_tools,
|
|
834
|
-
allow_text_result=self._allow_text_result,
|
|
835
|
-
result_tools=
|
|
1003
|
+
allow_text_result=self._allow_text_result(result_schema),
|
|
1004
|
+
result_tools=result_schema.tool_defs() if result_schema is not None else [],
|
|
836
1005
|
)
|
|
837
1006
|
|
|
1007
|
+
async def _reevaluate_dynamic_prompts(
|
|
1008
|
+
self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
|
|
1009
|
+
) -> None:
|
|
1010
|
+
"""Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
|
|
1011
|
+
# Only proceed if there's at least one dynamic runner.
|
|
1012
|
+
if self._system_prompt_dynamic_functions:
|
|
1013
|
+
for msg in messages:
|
|
1014
|
+
if isinstance(msg, _messages.ModelRequest):
|
|
1015
|
+
for i, part in enumerate(msg.parts):
|
|
1016
|
+
if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
|
|
1017
|
+
# Look up the runner by its ref
|
|
1018
|
+
if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
|
|
1019
|
+
updated_part_content = await runner.run(run_context)
|
|
1020
|
+
msg.parts[i] = _messages.SystemPromptPart(
|
|
1021
|
+
updated_part_content, dynamic_ref=part.dynamic_ref
|
|
1022
|
+
)
|
|
1023
|
+
|
|
1024
|
+
def _prepare_result_schema(
|
|
1025
|
+
self, result_type: type[RunResultData] | None
|
|
1026
|
+
) -> _result.ResultSchema[RunResultData] | None:
|
|
1027
|
+
if result_type is not None:
|
|
1028
|
+
if self._result_validators:
|
|
1029
|
+
raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
|
|
1030
|
+
return _result.ResultSchema[result_type].build(
|
|
1031
|
+
result_type, self._result_tool_name, self._result_tool_description
|
|
1032
|
+
)
|
|
1033
|
+
else:
|
|
1034
|
+
return self._result_schema # pyright: ignore[reportReturnType]
|
|
1035
|
+
|
|
838
1036
|
async def _prepare_messages(
|
|
839
1037
|
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
|
|
840
1038
|
) -> list[_messages.ModelMessage]:
|
|
@@ -850,8 +1048,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
850
1048
|
ctx_messages.used = True
|
|
851
1049
|
|
|
852
1050
|
if message_history:
|
|
853
|
-
#
|
|
1051
|
+
# Shallow copy messages
|
|
854
1052
|
messages.extend(message_history)
|
|
1053
|
+
# Reevaluate any dynamic system prompt parts
|
|
1054
|
+
await self._reevaluate_dynamic_prompts(messages, run_context)
|
|
855
1055
|
messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
|
|
856
1056
|
else:
|
|
857
1057
|
parts = await self._sys_parts(run_context)
|
|
@@ -861,8 +1061,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
861
1061
|
return messages
|
|
862
1062
|
|
|
863
1063
|
async def _handle_model_response(
|
|
864
|
-
self,
|
|
865
|
-
|
|
1064
|
+
self,
|
|
1065
|
+
model_response: _messages.ModelResponse,
|
|
1066
|
+
run_context: RunContext[AgentDeps],
|
|
1067
|
+
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1068
|
+
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
|
|
866
1069
|
"""Process a non-streamed response from the model.
|
|
867
1070
|
|
|
868
1071
|
Returns:
|
|
@@ -883,19 +1086,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
883
1086
|
# This accounts for cases like anthropic returns that might contain a text response
|
|
884
1087
|
# and a tool call response, where the text response just indicates the tool call will happen.
|
|
885
1088
|
if tool_calls:
|
|
886
|
-
return await self._handle_structured_response(tool_calls, run_context)
|
|
1089
|
+
return await self._handle_structured_response(tool_calls, run_context, result_schema)
|
|
887
1090
|
elif texts:
|
|
888
1091
|
text = '\n\n'.join(texts)
|
|
889
|
-
return await self._handle_text_response(text, run_context)
|
|
1092
|
+
return await self._handle_text_response(text, run_context, result_schema)
|
|
890
1093
|
else:
|
|
891
1094
|
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
892
1095
|
|
|
893
1096
|
async def _handle_text_response(
|
|
894
|
-
self, text: str, run_context: RunContext[AgentDeps]
|
|
895
|
-
) -> tuple[_MarkFinalResult[
|
|
1097
|
+
self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
|
|
1098
|
+
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
|
|
896
1099
|
"""Handle a plain text response from the model for non-streaming responses."""
|
|
897
|
-
if self._allow_text_result:
|
|
898
|
-
result_data_input = cast(
|
|
1100
|
+
if self._allow_text_result(result_schema):
|
|
1101
|
+
result_data_input = cast(RunResultData, text)
|
|
899
1102
|
try:
|
|
900
1103
|
result_data = await self._validate_result(result_data_input, run_context, None)
|
|
901
1104
|
except _result.ToolRetryError as e:
|
|
@@ -911,16 +1114,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
911
1114
|
return None, [response]
|
|
912
1115
|
|
|
913
1116
|
async def _handle_structured_response(
|
|
914
|
-
self,
|
|
915
|
-
|
|
1117
|
+
self,
|
|
1118
|
+
tool_calls: list[_messages.ToolCallPart],
|
|
1119
|
+
run_context: RunContext[AgentDeps],
|
|
1120
|
+
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1121
|
+
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
|
|
916
1122
|
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
|
|
917
1123
|
assert tool_calls, 'Expected at least one tool call'
|
|
918
1124
|
|
|
919
1125
|
# first look for the result tool call
|
|
920
|
-
final_result: _MarkFinalResult[
|
|
1126
|
+
final_result: _MarkFinalResult[RunResultData] | None = None
|
|
921
1127
|
|
|
922
1128
|
parts: list[_messages.ModelRequestPart] = []
|
|
923
|
-
if result_schema
|
|
1129
|
+
if result_schema is not None:
|
|
924
1130
|
if match := result_schema.find_tool(tool_calls):
|
|
925
1131
|
call, result_tool = match
|
|
926
1132
|
try:
|
|
@@ -933,7 +1139,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
933
1139
|
final_result = _MarkFinalResult(result_data, call.tool_name)
|
|
934
1140
|
|
|
935
1141
|
# Then build the other request parts based on end strategy
|
|
936
|
-
parts += await self._process_function_tools(
|
|
1142
|
+
parts += await self._process_function_tools(
|
|
1143
|
+
tool_calls, final_result and final_result.tool_name, run_context, result_schema
|
|
1144
|
+
)
|
|
937
1145
|
|
|
938
1146
|
return final_result, parts
|
|
939
1147
|
|
|
@@ -942,6 +1150,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
942
1150
|
tool_calls: list[_messages.ToolCallPart],
|
|
943
1151
|
result_tool_name: str | None,
|
|
944
1152
|
run_context: RunContext[AgentDeps],
|
|
1153
|
+
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
945
1154
|
) -> list[_messages.ModelRequestPart]:
|
|
946
1155
|
"""Process function (non-result) tool calls in parallel.
|
|
947
1156
|
|
|
@@ -975,7 +1184,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
975
1184
|
)
|
|
976
1185
|
else:
|
|
977
1186
|
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
|
|
978
|
-
elif
|
|
1187
|
+
elif result_schema is not None and call.tool_name in result_schema.tools:
|
|
979
1188
|
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
|
|
980
1189
|
# validation, we don't add another part here
|
|
981
1190
|
if result_tool_name is not None:
|
|
@@ -987,7 +1196,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
987
1196
|
)
|
|
988
1197
|
)
|
|
989
1198
|
else:
|
|
990
|
-
parts.append(self._unknown_tool(call.tool_name, run_context))
|
|
1199
|
+
parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
|
|
991
1200
|
|
|
992
1201
|
# Run all tool tasks in parallel
|
|
993
1202
|
if tasks:
|
|
@@ -996,85 +1205,72 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
996
1205
|
parts.extend(task_results)
|
|
997
1206
|
return parts
|
|
998
1207
|
|
|
999
|
-
async def
|
|
1208
|
+
async def _handle_streamed_response(
    self,
    streamed_response: models.StreamedResponse,
    run_context: RunContext[AgentDeps],
    result_schema: _result.ResultSchema[RunResultData] | None,
) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
    """Process a streamed response from the model.

    The stream is consumed until a part starts that can serve as the final result
    (a text part when text results are allowed, or a call to a result tool). If the
    stream is exhausted first, the requested function tools are run instead.

    Returns:
        Either a final result or a tuple of the model response and the tool responses for the next request.
        If a final result is returned, the conversation should end.

    Raises:
        UnexpectedModelBehavior: If the fully consumed response contains no parts.
    """
    saw_text = False

    async for event in streamed_response:
        # Only the start of a new part can decide whether the final result has begun.
        if not isinstance(event, _messages.PartStartEvent):
            continue
        started = event.part
        if isinstance(started, _messages.TextPart):
            saw_text = True
            if self._allow_text_result(result_schema):
                return _MarkFinalResult(streamed_response, None)
        elif isinstance(started, _messages.ToolCallPart):
            found = result_schema.find_tool([started]) if result_schema is not None else None
            if found:
                result_call, _ = found
                return _MarkFinalResult(streamed_response, result_call.tool_name)
        else:
            assert_never(started)

    # The stream ended without a final result: collect the function tool calls to run.
    pending: list[asyncio.Task[_messages.ModelRequestPart]] = []
    request_parts: list[_messages.ModelRequestPart] = []
    model_response = streamed_response.get()
    if not model_response.parts:
        raise exceptions.UnexpectedModelBehavior('Received empty model response')

    for part in model_response.parts:
        if not isinstance(part, _messages.ToolCallPart):
            continue
        tool = self._function_tools.get(part.tool_name)
        if tool:
            pending.append(asyncio.create_task(tool.run(part, run_context), name=part.tool_name))
        else:
            request_parts.append(self._unknown_tool(part.tool_name, run_context, result_schema))

    if saw_text and not (pending or request_parts):
        # Can only get here if self._allow_text_result returns `False` for the provided result_schema
        self._incr_result_retry(run_context)
        retry_part = _messages.RetryPromptPart(
            content='Plain text responses are not permitted, please call one of the functions instead.',
        )
        return streamed_response.get(), [retry_part]

    # Await every scheduled tool concurrently and append their results after any retry prompts.
    with _logfire.span('running {tools=}', tools=[t.get_name() for t in pending]):
        results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*pending)
        request_parts.extend(results)
    return model_response, request_parts
|
|
1068
1260
|
|
|
1069
1261
|
async def _validate_result(
    self,
    result_data: RunResultData,
    run_context: RunContext[AgentDeps],
    tool_call: _messages.ToolCallPart | None,
) -> RunResultData:
    """Run the agent's result validators over `result_data`.

    Validators are applied in order, each receiving the value returned by the
    previous one. When the agent has no validators the input is returned as is.

    Args:
        result_data: The candidate result value.
        run_context: The context of the current run, forwarded to each validator.
        tool_call: The tool call that produced the result, or `None`; forwarded to each validator.

    Returns:
        The value returned by the last validator, or `result_data` when there are none.
    """
    if not self._result_validators:
        return result_data

    # Validators are typed against the agent's `ResultData`, hence the casts on the way in and out.
    validated = cast(ResultData, result_data)
    for validator in self._result_validators:
        validated = await validator.validate(validated, tool_call, run_context)
    return cast(RunResultData, validated)
|
|
1078
1274
|
|
|
1079
1275
|
def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
|
|
1080
1276
|
run_context.retry += 1
|
|
@@ -1088,14 +1284,22 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1088
1284
|
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
|
|
1089
1285
|
for sys_prompt_runner in self._system_prompt_functions:
|
|
1090
1286
|
prompt = await sys_prompt_runner.run(run_context)
|
|
1091
|
-
|
|
1287
|
+
if sys_prompt_runner.dynamic:
|
|
1288
|
+
messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
|
|
1289
|
+
else:
|
|
1290
|
+
messages.append(_messages.SystemPromptPart(prompt))
|
|
1092
1291
|
return messages
|
|
1093
1292
|
|
|
1094
|
-
def _unknown_tool(
|
|
1293
|
+
def _unknown_tool(
|
|
1294
|
+
self,
|
|
1295
|
+
tool_name: str,
|
|
1296
|
+
run_context: RunContext[AgentDeps],
|
|
1297
|
+
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1298
|
+
) -> _messages.RetryPromptPart:
|
|
1095
1299
|
self._incr_result_retry(run_context)
|
|
1096
1300
|
names = list(self._function_tools.keys())
|
|
1097
|
-
if
|
|
1098
|
-
names.extend(
|
|
1301
|
+
if result_schema:
|
|
1302
|
+
names.extend(result_schema.tool_names())
|
|
1099
1303
|
if names:
|
|
1100
1304
|
msg = f'Available tools: {", ".join(names)}'
|
|
1101
1305
|
else:
|
|
@@ -1133,6 +1337,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1133
1337
|
self.name = name
|
|
1134
1338
|
return
|
|
1135
1339
|
|
|
1340
|
+
@staticmethod
|
|
1341
|
+
def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
|
|
1342
|
+
return result_schema is None or result_schema.allow_text_result
|
|
1343
|
+
|
|
1136
1344
|
@property
|
|
1137
1345
|
@deprecated(
|
|
1138
1346
|
'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
|