pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of pydantic-ai-slim has been flagged as a potentially problematic release.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_result.py +4 -7
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +85 -75
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +24 -36
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +63 -33
- pydantic_ai/settings.py +65 -0
- pydantic_ai/tools.py +24 -14
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +0 -0
pydantic_ai/__init__.py
CHANGED
@@ -1,8 +1,18 @@
 from importlib.metadata import version
 
 from .agent import Agent
-from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
+from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
 from .tools import RunContext, Tool
 
-__all__ =
+__all__ = (
+    'Agent',
+    'RunContext',
+    'Tool',
+    'AgentRunError',
+    'ModelRetry',
+    'UnexpectedModelBehavior',
+    'UsageLimitExceeded',
+    'UserError',
+    '__version__',
+)
 __version__ = version('pydantic_ai_slim')
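The package root now defines an explicit `__all__`, so everything in it can be imported directly from `pydantic_ai`. A minimal sketch of the two newly exported names (the assertion is illustrative, not part of the diff):

    from pydantic_ai import Agent, AgentRunError, ModelRetry, RunContext, Tool, UsageLimitExceeded

    # AgentRunError and UsageLimitExceeded are new in 0.0.14; both are defined
    # in pydantic_ai.exceptions (see that file's diff below)
    assert issubclass(UsageLimitExceeded, AgentRunError)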
pydantic_ai/_result.py
CHANGED
@@ -29,25 +29,22 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
     async def validate(
         self,
         result: ResultData,
-        deps: AgentDeps,
-        retry: int,
         tool_call: _messages.ToolCallPart | None,
-
+        run_context: RunContext[AgentDeps],
     ) -> ResultData:
         """Validate a result but calling the function.
 
         Args:
             result: The result data after Pydantic validation the message content.
-            deps: The agent dependencies.
-            retry: The current retry number.
             tool_call: The original tool call message, `None` if there was no tool call.
-
+            run_context: The current run context.
 
         Returns:
             Result of either the validated result data (ok) or a retry message (Err).
         """
         if self._takes_ctx:
-
+            ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
+            args = ctx, result
         else:
             args = (result,)
 
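For orientation, a hedged sketch of a result validator written against the new signature: the validator function now receives the whole `RunContext` instead of separate `deps`/`retry` arguments. The agent setup and names here are illustrative, not taken from the diff:

    from pydantic_ai import Agent, ModelRetry, RunContext

    agent = Agent('test', deps_type=str)

    @agent.result_validator
    async def check_result(ctx: RunContext[str], value: str) -> str:
        # ctx bundles deps, retry count, messages and model in one object
        if ctx.deps not in value:
            raise ModelRetry('result must mention the deps string')
        return value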
pydantic_ai/_system_prompt.py
CHANGED
@@ -19,9 +19,9 @@ class SystemPromptRunner(Generic[AgentDeps]):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
         self._is_async = inspect.iscoroutinefunction(self.function)
 
-    async def run(self,
+    async def run(self, run_context: RunContext[AgentDeps]) -> str:
         if self._takes_ctx:
-            args = (
+            args = (run_context,)
         else:
             args = ()
 
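The user-facing counterpart: a system prompt function that accepts the run context, which the runner now passes in place of bare deps. A minimal sketch (agent setup illustrative):

    from pydantic_ai import Agent, RunContext

    agent = Agent('test', deps_type=str)

    @agent.system_prompt
    async def personalized_prompt(ctx: RunContext[str]) -> str:
        return f'The user you are helping is called {ctx.deps}.'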
pydantic_ai/agent.py
CHANGED
@@ -22,7 +22,7 @@ from . import (
     result,
 )
 from .result import ResultData
-from .settings import ModelSettings, merge_model_settings
+from .settings import ModelSettings, UsageLimits, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -104,7 +104,6 @@ class Agent(Generic[AgentDeps, ResultData]):
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
     _deps_type: type[AgentDeps] = field(repr=False)
     _max_result_retries: int = field(repr=False)
-    _current_result_retry: int = field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
     _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
@@ -180,7 +179,6 @@ class Agent(Generic[AgentDeps, ResultData]):
         self._deps_type = deps_type
         self._system_prompt_functions = []
         self._max_result_retries = result_retries if result_retries is not None else retries
-        self._current_result_retry = 0
         self._result_validators = []
 
     async def run(
@@ -191,6 +189,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -211,8 +210,9 @@ class Agent(Generic[AgentDeps, ResultData]):
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
             The result of the run.
@@ -232,31 +232,37 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-
+            run_context = RunContext(deps, 0, [], None, model_used)
+            messages = await self._prepare_messages(user_prompt, message_history, run_context)
+            self.last_run_messages = run_context.messages = messages
 
             for tool in self._function_tools.values():
                 tool.current_retry = 0
 
-
-
+            usage = result.Usage(requests=0)
             model_settings = merge_model_settings(self.model_settings, model_settings)
+            usage_limits = usage_limits or UsageLimits()
 
             run_step = 0
             while True:
+                usage_limits.check_before_request(usage)
+
                 run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(
+                    agent_model = await self._prepare_model(run_context)
 
                 with _logfire.span('model request', run_step=run_step) as model_req_span:
-                    model_response,
+                    model_response, request_usage = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
-                    model_req_span.set_attribute('
+                    model_req_span.set_attribute('usage', request_usage)
 
                 messages.append(model_response)
-
+                usage += request_usage
+                usage.requests += 1
+                usage_limits.check_tokens(request_usage)
 
                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                    final_result, tool_responses = await self._handle_model_response(model_response,
+                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)
 
                     if tool_responses:
                         # Add parts to the conversation as a new message
@@ -266,10 +272,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                     if final_result is not None:
                         result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('
+                        run_span.set_attribute('usage', usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data,
+                        return result.RunResult(messages, new_message_index, result_data, usage)
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
@@ -284,6 +290,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -308,8 +315,9 @@ class Agent(Generic[AgentDeps, ResultData]):
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
             The result of the run.
@@ -322,8 +330,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                 message_history=message_history,
                 model=model,
                 deps=deps,
-                infer_name=False,
                 model_settings=model_settings,
+                usage_limits=usage_limits,
+                infer_name=False,
             )
         )
 
@@ -336,6 +345,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +367,9 @@ class Agent(Generic[AgentDeps, ResultData]):
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
             The result of the run.
@@ -380,32 +391,35 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-
+            run_context = RunContext(deps, 0, [], None, model_used)
+            messages = await self._prepare_messages(user_prompt, message_history, run_context)
+            self.last_run_messages = run_context.messages = messages
 
             for tool in self._function_tools.values():
                 tool.current_retry = 0
 
-
+            usage = result.Usage()
             model_settings = merge_model_settings(self.model_settings, model_settings)
+            usage_limits = usage_limits or UsageLimits()
 
             run_step = 0
             while True:
                 run_step += 1
+                usage_limits.check_before_request(usage)
 
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(
+                    agent_model = await self._prepare_model(run_context)
 
                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
+                        usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
                         model_req_span.__exit__(None, None, None)
 
                         with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(
-                                model_response, deps, messages
-                            )
+                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
 
                             # Check if we got a final result
                             if isinstance(maybe_final_result, _MarkFinalResult):
@@ -425,7 +439,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                                 ]
                                 parts = await self._process_function_tools(
-                                    tool_calls, result_tool_name,
+                                    tool_calls, result_tool_name, run_context
                                 )
                                 if parts:
                                     messages.append(_messages.ModelRequest(parts))
@@ -434,10 +448,11 @@ class Agent(Generic[AgentDeps, ResultData]):
                                 yield result.StreamedRunResult(
                                     messages,
                                     new_message_index,
-
+                                    usage,
+                                    usage_limits,
                                     result_stream,
                                     self._result_schema,
-
+                                    run_context,
                                     self._result_validators,
                                     result_tool_name,
                                     on_complete,
@@ -455,8 +470,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                             handle_span.set_attribute('tool_responses', tool_responses)
                             tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                             handle_span.message = f'handle model response -> {tool_responses_str}'
-                            # the model_response should have been fully streamed by now, we can add
-
+                            # the model_response should have been fully streamed by now, we can add its usage
+                            model_response_usage = model_response.usage()
+                            usage += model_response_usage
+                            usage_limits.check_tokens(usage)
 
     @contextmanager
     def override(
@@ -798,41 +815,39 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         return model_, mode_selection
 
-    async def _prepare_model(
-
-    ) -> models.AgentModel:
-        """Create building tools and create an agent model."""
+    async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
+        """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []
 
         async def add_tool(tool: Tool[AgentDeps]) -> None:
-            ctx =
+            ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
             if tool_def := await tool.prepare_tool_def(ctx):
                 function_tools.append(tool_def)
 
         await asyncio.gather(*map(add_tool, self._function_tools.values()))
 
-        return await model.agent_model(
+        return await run_context.model.agent_model(
             function_tools=function_tools,
             allow_text_result=self._allow_text_result,
             result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
         )
 
     async def _prepare_messages(
-        self,
+        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
         if message_history:
             # shallow copy messages
             messages = message_history.copy()
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
-            parts = await self._sys_parts(
+            parts = await self._sys_parts(run_context)
             parts.append(_messages.UserPromptPart(user_prompt))
             messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
 
         return messages
 
     async def _handle_model_response(
-        self, model_response: _messages.ModelResponse,
+        self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.
@@ -841,42 +856,44 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         texts: list[str] = []
         tool_calls: list[_messages.ToolCallPart] = []
-        for
-        if isinstance(
-
+        for part in model_response.parts:
+            if isinstance(part, _messages.TextPart):
+                # ignore empty content for text parts, see #437
+                if part.content:
+                    texts.append(part.content)
             else:
-                tool_calls.append(
+                tool_calls.append(part)
 
         if texts:
             text = '\n\n'.join(texts)
-            return await self._handle_text_response(text,
+            return await self._handle_text_response(text, run_context)
         elif tool_calls:
-            return await self._handle_structured_response(tool_calls,
+            return await self._handle_structured_response(tool_calls, run_context)
         else:
            raise exceptions.UnexpectedModelBehavior('Received empty model response')
 
     async def _handle_text_response(
-        self, text: str,
+        self, text: str, run_context: RunContext[AgentDeps]
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a plain text response from the model for non-streaming responses."""
         if self._allow_text_result:
             result_data_input = cast(ResultData, text)
             try:
-                result_data = await self._validate_result(result_data_input,
+                result_data = await self._validate_result(result_data_input, run_context, None)
             except _result.ToolRetryError as e:
-                self._incr_result_retry()
+                self._incr_result_retry(run_context)
                 return None, [e.tool_retry]
             else:
                 return _MarkFinalResult(result_data, None), []
         else:
-            self._incr_result_retry()
+            self._incr_result_retry(run_context)
            response = _messages.RetryPromptPart(
                content='Plain text responses are not permitted, please call one of the functions instead.',
            )
            return None, [response]
 
     async def _handle_structured_response(
-        self, tool_calls: list[_messages.ToolCallPart],
+        self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'
@@ -890,17 +907,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 call, result_tool = match
                 try:
                     result_data = result_tool.validate(call)
-                    result_data = await self._validate_result(result_data,
+                    result_data = await self._validate_result(result_data, run_context, call)
                 except _result.ToolRetryError as e:
-                    self._incr_result_retry()
+                    self._incr_result_retry(run_context)
                     parts.append(e.tool_retry)
                 else:
                     final_result = _MarkFinalResult(result_data, call.tool_name)
 
         # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(
-            tool_calls, final_result and final_result.tool_name, deps, conv_messages
-        )
+        parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
 
         return final_result, parts
@@ -908,8 +923,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
-        deps: AgentDeps,
-        conv_messages: list[_messages.ModelMessage],
+        run_context: RunContext[AgentDeps],
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.
@@ -942,7 +956,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                         )
                     )
                 else:
-                    tasks.append(asyncio.create_task(tool.run(
+                    tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
             elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
                 # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
                 # validation, we don't add another part here
@@ -955,7 +969,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     )
                 )
             else:
-                parts.append(self._unknown_tool(call.tool_name))
+                parts.append(self._unknown_tool(call.tool_name, run_context))
 
         # Run all tool tasks in parallel
         if tasks:
@@ -967,8 +981,7 @@ class Agent(Generic[AgentDeps, ResultData]):
     async def _handle_streamed_model_response(
         self,
         model_response: models.EitherStreamedResponse,
-        deps: AgentDeps,
-        conv_messages: list[_messages.ModelMessage],
+        run_context: RunContext[AgentDeps],
     ) -> (
         _MarkFinalResult[models.EitherStreamedResponse]
         | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -984,11 +997,11 @@ class Agent(Generic[AgentDeps, ResultData]):
             if self._allow_text_result:
                 return _MarkFinalResult(model_response, None)
             else:
-                self._incr_result_retry()
+                self._incr_result_retry(run_context)
                 response = _messages.RetryPromptPart(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                # stream the response, so
+                # stream the response, so usage is correct
                 async for _ in model_response:
                     pass
@@ -1024,9 +1037,9 @@ class Agent(Generic[AgentDeps, ResultData]):
             if isinstance(item, _messages.ToolCallPart):
                 call = item
                 if tool := self._function_tools.get(call.tool_name):
-                    tasks.append(asyncio.create_task(tool.run(
+                    tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
                 else:
-                    parts.append(self._unknown_tool(call.tool_name))
+                    parts.append(self._unknown_tool(call.tool_name, run_context))
 
         with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
             task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1038,33 +1051,30 @@ class Agent(Generic[AgentDeps, ResultData]):
     async def _validate_result(
         self,
         result_data: ResultData,
-        deps: AgentDeps,
+        run_context: RunContext[AgentDeps],
         tool_call: _messages.ToolCallPart | None,
-        conv_messages: list[_messages.ModelMessage],
     ) -> ResultData:
         for validator in self._result_validators:
-            result_data = await validator.validate(
-                result_data, deps, self._current_result_retry, tool_call, conv_messages
-            )
+            result_data = await validator.validate(result_data, tool_call, run_context)
         return result_data
 
-    def _incr_result_retry(self) -> None:
-        self._current_result_retry += 1
-        if self._current_result_retry > self._max_result_retries:
+    def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
+        run_context.retry += 1
+        if run_context.retry > self._max_result_retries:
             raise exceptions.UnexpectedModelBehavior(
                 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
             )
 
-    async def _sys_parts(self,
+    async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
         messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
-            prompt = await sys_prompt_runner.run(
+            prompt = await sys_prompt_runner.run(run_context)
             messages.append(_messages.SystemPromptPart(prompt))
         return messages
 
-    def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
-        self._incr_result_retry()
+    def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
+        self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
         if self._result_schema:
             names.extend(self._result_schema.tool_names())
pydantic_ai/exceptions.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import json
 
-__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
+__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
 
 
 class ModelRetry(Exception):
@@ -30,7 +30,25 @@ class UserError(RuntimeError):
         super().__init__(message)
 
 
-class UnexpectedModelBehavior(RuntimeError):
+class AgentRunError(RuntimeError):
+    """Base class for errors occurring during an agent run."""
+
+    message: str
+    """The error message."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return self.message
+
+
+class UsageLimitExceeded(AgentRunError):
+    """Error raised when a Model's usage exceeds the specified limits."""
+
+
+class UnexpectedModelBehavior(AgentRunError):
     """Error caused by unexpected Model behavior, e.g. an unexpected response code."""
 
     message: str
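With `AgentRunError` as the shared base class, callers can catch every run-time agent failure in one place or distinguish usage-limit aborts from model misbehaviour. A brief sketch (the helper function is illustrative):

    from pydantic_ai.exceptions import AgentRunError, UnexpectedModelBehavior, UsageLimitExceeded

    def describe(exc: AgentRunError) -> str:
        # both subclasses inherit the .message attribute from the base class
        if isinstance(exc, UsageLimitExceeded):
            return f'usage limit hit: {exc.message}'
        if isinstance(exc, UnexpectedModelBehavior):
            return f'model misbehaved: {exc.message}'
        return exc.message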
pydantic_ai/messages.py
CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations as _annotations
 
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Union, cast
 
 import pydantic
 import pydantic_core
-from typing_extensions import Self
+from typing_extensions import Self, assert_never
 
 from ._utils import now_utc as _now_utc
 
@@ -190,12 +190,34 @@ class ToolCallPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
     @classmethod
-    def
-
+    def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
+        """Create a `ToolCallPart` from raw arguments."""
+        if isinstance(args, str):
+            return cls(tool_name, ArgsJson(args), tool_call_id)
+        elif isinstance(args, dict):
+            return cls(tool_name, ArgsDict(args), tool_call_id)
+        else:
+            assert_never(args)
 
-
-
-
+    def args_as_dict(self) -> dict[str, Any]:
+        """Return the arguments as a Python dictionary.
+
+        This is just for convenience with models that require dicts as input.
+        """
+        if isinstance(self.args, ArgsDict):
+            return self.args.args_dict
+        args = pydantic_core.from_json(self.args.args_json)
+        assert isinstance(args, dict), 'args should be a dict'
+        return cast(dict[str, Any], args)
+
+    def args_as_json_str(self) -> str:
+        """Return the arguments as a JSON string.
+
+        This is just for convenience with models that require JSON strings as input.
+        """
+        if isinstance(self.args, ArgsJson):
+            return self.args.args_json
+        return pydantic_core.to_json(self.args.args_dict).decode()
 
     def has_content(self) -> bool:
         if isinstance(self.args, ArgsDict):
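These helpers normalise tool-call arguments between the JSON-string form used by OpenAI-style APIs and the dict form used by e.g. Gemini, which presumably explains the reduced line counts in the model adapter diffs listed above. A quick sketch of the new surface:

    from pydantic_ai.messages import ToolCallPart

    # from_raw_args picks ArgsJson or ArgsDict based on the input type
    call_a = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}')
    call_b = ToolCallPart.from_raw_args('get_weather', {'city': 'Paris'})

    # either representation can be read back in whichever form a model needs
    assert call_a.args_as_dict() == {'city': 'Paris'}
    assert call_b.args_as_json_str() == '{"city":"Paris"}'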
|