pydantic-ai-slim 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_pydantic.py +6 -4
- pydantic_ai/_result.py +18 -22
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +11 -6
- pydantic_ai/agent.py +156 -69
- pydantic_ai/messages.py +5 -2
- pydantic_ai/models/__init__.py +30 -37
- pydantic_ai/models/function.py +8 -14
- pydantic_ai/models/gemini.py +11 -10
- pydantic_ai/models/groq.py +31 -34
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +43 -38
- pydantic_ai/models/test.py +70 -49
- pydantic_ai/models/vertexai.py +7 -6
- pydantic_ai/tools.py +119 -34
- {pydantic_ai_slim-0.0.10.dist-info → pydantic_ai_slim-0.0.12.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.10.dist-info/RECORD +0 -22
- {pydantic_ai_slim-0.0.10.dist-info → pydantic_ai_slim-0.0.12.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py (CHANGED)

```diff
@@ -22,7 +22,17 @@ from . import (
     result,
 )
 from .result import ResultData
-from .tools import
+from .tools import (
+    AgentDeps,
+    RunContext,
+    Tool,
+    ToolDefinition,
+    ToolFuncContext,
+    ToolFuncEither,
+    ToolFuncPlain,
+    ToolParams,
+    ToolPrepareFunc,
+)
 
 __all__ = ('Agent',)
 
```
```diff
@@ -136,7 +146,10 @@ class Agent(Generic[AgentDeps, ResultData]):
         self._function_tools = {}
         self._default_retries = retries
         for tool in tools:
-
+            if isinstance(tool, Tool):
+                self._register_tool(tool)
+            else:
+                self._register_tool(Tool(tool))
         self._deps_type = deps_type
         self._system_prompt_functions = []
         self._max_result_retries = result_retries if result_retries is not None else retries
```
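With this change, `tools` accepts plain functions as well as `Tool` instances. A minimal sketch of the two registration styles (the `roll_die` tool is illustrative, not from the package):

```python
import random

from pydantic_ai import Agent
from pydantic_ai.tools import Tool

def roll_die() -> str:
    """Roll a six-sided die and return the result."""
    return str(random.randint(1, 6))

# a plain function is wrapped in Tool(...) internally
agent_a = Agent('test', tools=[roll_die])

# an explicit Tool instance lets you configure it up front, e.g. max_retries
agent_b = Agent('test', tools=[Tool(roll_die, max_retries=2)])
```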
```diff
@@ -166,29 +179,33 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used,
+        model_used, mode_selection = await self._get_model(model)
 
         deps = self._get_deps(deps)
 
         with _logfire.span(
-            '{
+            '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-
+            mode_selection=mode_selection,
             model_name=model_used.name(),
+            agent_name=self.name or 'agent',
         ) as run_span:
             new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
             self.last_run_messages = messages
 
             for tool in self._function_tools.values():
-                tool.
+                tool.current_retry = 0
 
             cost = result.Cost()
 
             run_step = 0
             while True:
                 run_step += 1
-                with _logfire.span('model
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                    agent_model = await self._prepare_model(model_used, deps)
+
+                with _logfire.span('model request', run_step=run_step) as model_req_span:
                     model_response, request_cost = await agent_model.request(messages)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('cost', request_cost)
```
```diff
@@ -197,12 +214,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 messages.append(model_response)
                 cost += request_cost
 
-                with _logfire.span('handle model response') as handle_span:
-
+                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                    final_result, response_messages = await self._handle_model_response(model_response, deps)
 
-
-
-
+                    # Add all messages to the conversation
+                    messages.extend(response_messages)
+
+                    # Check if we got a final result
+                    if final_result is not None:
+                        result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
                         run_span.set_attribute('cost', cost)
                         handle_span.set_attribute('result', result_data)
```
```diff
@@ -210,11 +230,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                         return result.RunResult(messages, new_message_index, result_data, cost)
                     else:
                         # continue the conversation
-                        tool_responses
-
-                        response_msgs = ' '.join(m.role for m in tool_responses)
+                        handle_span.set_attribute('tool_responses', response_messages)
+                        response_msgs = ' '.join(r.role for r in response_messages)
                         handle_span.message = f'handle model response -> {response_msgs}'
-                        messages.extend(tool_responses)
 
     def run_sync(
         self,
```
```diff
@@ -272,28 +290,33 @@ class Agent(Generic[AgentDeps, ResultData]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used,
+        model_used, mode_selection = await self._get_model(model)
 
         deps = self._get_deps(deps)
 
         with _logfire.span(
-            '{
+            '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-
+            mode_selection=mode_selection,
             model_name=model_used.name(),
+            agent_name=self.name or 'agent',
         ) as run_span:
             new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
             self.last_run_messages = messages
 
             for tool in self._function_tools.values():
-                tool.
+                tool.current_retry = 0
 
             cost = result.Cost()
 
             run_step = 0
             while True:
                 run_step += 1
+
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                    agent_model = await self._prepare_model(model_used, deps)
+
                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                     async with agent_model.request_stream(messages) as model_response:
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
@@ -302,10 +325,16 @@ class Agent(Generic[AgentDeps, ResultData]):
                         model_req_span.__exit__(None, None, None)
 
                 with _logfire.span('handle model response') as handle_span:
-
+                    final_result, response_messages = await self._handle_streamed_model_response(
+                        model_response, deps
+                    )
+
+                    # Add all messages to the conversation
+                    messages.extend(response_messages)
 
-                    if
-
+                    # Check if we got a final result
+                    if final_result is not None:
+                        result_stream = final_result.data
                         run_span.set_attribute('all_messages', messages)
                         handle_span.set_attribute('result_type', result_stream.__class__.__name__)
                         handle_span.message = 'handle model response -> final result'
@@ -321,11 +350,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                             )
                             return
                     else:
-
-                        handle_span.set_attribute('tool_responses',
-                        response_msgs = ' '.join(
+                        # continue the conversation
+                        handle_span.set_attribute('tool_responses', response_messages)
+                        response_msgs = ' '.join(r.role for r in response_messages)
                         handle_span.message = f'handle model response -> {response_msgs}'
-                        messages.extend(tool_responses)
                     # the model_response should have been fully streamed by now, we can add it's cost
                     cost += model_response.cost()
 
```
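For reference, this loop backs `run_stream`; a minimal consumption sketch (assuming the streaming result API of this release, e.g. `stream_text`, and an OpenAI key in the environment):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

async def main():
    async with agent.run_stream('Tell me a short story.') as result:
        # stream_text yields the text response as it arrives
        async for chunk in result.stream_text():
            print(chunk)

asyncio.run(main())
```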
```diff
@@ -475,7 +503,11 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     @overload
     def tool(
-        self,
+        self,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
     ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
 
     def tool(
@@ -484,9 +516,9 @@ class Agent(Generic[AgentDeps, ResultData]):
         /,
         *,
         retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
     ) -> Any:
-        """Decorator to register a tool function which takes
-        [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
+        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
 
         Can decorate a sync or async functions.
 
@@ -519,20 +551,23 @@ class Agent(Generic[AgentDeps, ResultData]):
             func: The tool function to register.
             retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                 which defaults to 1.
-
+            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
+                tool from a given step. This is useful if you want to customise a tool at call time,
+                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+        """
         if func is None:
 
             def tool_decorator(
                 func_: ToolFuncContext[AgentDeps, ToolParams],
             ) -> ToolFuncContext[AgentDeps, ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, True, retries)
+                self._register_function(func_, True, retries, prepare)
                 return func_
 
             return tool_decorator
         else:
             # noinspection PyTypeChecker
-            self._register_function(func, True, retries)
+            self._register_function(func, True, retries, prepare)
             return func
 
     @overload
```
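The new `prepare` hook lets a tool rewrite or suppress its own definition on every step. A minimal sketch, mirroring the docstring above (the `only_if_42` logic is illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.tools import RunContext, ToolDefinition

agent = Agent('test', deps_type=int)

async def only_if_42(ctx: RunContext[int], tool_def: ToolDefinition) -> ToolDefinition | None:
    # returning None omits the tool from this step entirely
    return tool_def if ctx.deps == 42 else None

@agent.tool(prepare=only_if_42)
def hitchhiker(ctx: RunContext[int], answer: str) -> str:
    return f'{ctx.deps} {answer}'
```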
```diff
@@ -540,10 +575,21 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     @overload
     def tool_plain(
-        self,
+        self,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
 
-    def tool_plain(
+    def tool_plain(
+        self,
+        func: ToolFuncPlain[ToolParams] | None = None,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
+    ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
 
         Can decorate a sync or async functions.
@@ -577,30 +623,38 @@ class Agent(Generic[AgentDeps, ResultData]):
             func: The tool function to register.
             retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                 which defaults to 1.
+            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
+                tool from a given step. This is useful if you want to customise a tool at call time,
+                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
         """
         if func is None:
 
             def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, False, retries)
+                self._register_function(func_, False, retries, prepare)
                 return func_
 
             return tool_decorator
         else:
-            self._register_function(func, False, retries)
+            self._register_function(func, False, retries, prepare)
             return func
 
     def _register_function(
-        self,
+        self,
+        func: ToolFuncEither[AgentDeps, ToolParams],
+        takes_ctx: bool,
+        retries: int | None,
+        prepare: ToolPrepareFunc[AgentDeps] | None,
     ) -> None:
         """Private utility to register a function as a tool."""
         retries_ = retries if retries is not None else self._default_retries
-        tool = Tool(func, takes_ctx, max_retries=retries_)
+        tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
         self._register_tool(tool)
 
     def _register_tool(self, tool: Tool[AgentDeps]) -> None:
         """Private utility to register a tool instance."""
         if tool.max_retries is None:
+            # noinspection PyTypeChecker
             tool = dataclasses.replace(tool, max_retries=self._default_retries)
 
         if tool.name in self._function_tools:
```
```diff
@@ -611,16 +665,14 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         self._function_tools[tool.name] = tool
 
-    async def
-        self, model: models.Model | models.KnownModelName | None
-    ) -> tuple[models.Model, models.Model | None, models.AgentModel]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
         """Create a model configured for this agent.
 
         Args:
             model: model to use for this run, required if `model` was not set when creating the agent.
 
         Returns:
-            a tuple of `(model used,
+            a tuple of `(model used, how the model was selected)`
         """
         model_: models.Model
         if some_model := self._override_model:
@@ -631,19 +683,35 @@ class Agent(Generic[AgentDeps, ResultData]):
                 '(Even when `override(model=...)` is customizing the model that will actually be called)'
             )
             model_ = some_model.value
-
+            mode_selection = 'override-model'
         elif model is not None:
-
+            model_ = models.infer_model(model)
+            mode_selection = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-
+            mode_selection = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-
-
-
+        return model_, mode_selection
+
+    async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
+        """Create building tools and create an agent model."""
+        function_tools: list[ToolDefinition] = []
+
+        async def add_tool(tool: Tool[AgentDeps]) -> None:
+            ctx = RunContext(deps, tool.current_retry, tool.name)
+            if tool_def := await tool.prepare_tool_def(ctx):
+                function_tools.append(tool_def)
+
+        await asyncio.gather(*map(add_tool, self._function_tools.values()))
+
+        return await model.agent_model(
+            function_tools=function_tools,
+            allow_text_result=self._allow_text_result,
+            result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
+        )
 
     async def _prepare_messages(
         self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
```
```diff
@@ -663,11 +731,11 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     async def _handle_model_response(
         self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
         """Process a non-streamed response from the model.
 
         Returns:
-
+            A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
         """
         if model_response.role == 'model-text-response':
             # plain string response
@@ -677,15 +745,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                     result_data = await self._validate_result(result_data_input, deps, None)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return [e.tool_retry]
+                    return None, [e.tool_retry]
                 else:
-                    return _MarkFinalResult(result_data)
+                    return _MarkFinalResult(result_data), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                return [response]
+                return None, [response]
         elif model_response.role == 'model-structured-response':
             if self._result_schema is not None:
                 # if there's a result schema, and any of the calls match one of its tools, return the result
@@ -697,9 +765,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                         result_data = await self._validate_result(result_data, deps, call)
                     except _result.ToolRetryError as e:
                         self._incr_result_retry()
-                        return [e.tool_retry]
+                        return None, [e.tool_retry]
                     else:
-
+                        # Add a ToolReturn message for the schema tool call
+                        tool_return = _messages.ToolReturn(
+                            tool_name=call.tool_name,
+                            content='Final result processed.',
+                            tool_id=call.tool_id,
+                        )
+                        return _MarkFinalResult(result_data), [tool_return]
 
             if not model_response.calls:
                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
```
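Both response handlers now return a `(final_result, messages)` pair instead of a union type, so the caller no longer needs `isinstance` checks. Schematically, the run loop consumes it like this (simplified; `handle` is a stand-in for the private methods):

```python
# illustrative control flow only, not verbatim source
final_result, response_messages = await handle(model_response)
messages.extend(response_messages)  # tool returns and/or retry prompts
if final_result is not None:
    return final_result.data        # conversation ends here
# otherwise: loop and issue the next model request
```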
```diff
@@ -714,26 +788,24 @@ class Agent(Generic[AgentDeps, ResultData]):
                     messages.append(self._unknown_tool(call.tool_name))
 
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-
-
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages
         else:
             assert_never(model_response)
 
     async def _handle_streamed_model_response(
         self, model_response: models.EitherStreamedResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
         """Process a streamed response from the model.
 
-        TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
-        (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
-
         Returns:
-
+            A tuple of (final_result, messages). If final_result is not None, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response)
+                return _MarkFinalResult(model_response), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
@@ -743,7 +815,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 async for _ in model_response:
                     pass
 
-                return [response]
+                return None, [response]
         else:
             assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
             if self._result_schema is not None:
@@ -757,8 +829,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                         break
                 structured_msg = model_response.get()
 
-                if self._result_schema.find_tool(structured_msg):
-
+                if match := self._result_schema.find_tool(structured_msg):
+                    call, _ = match
+                    tool_return = _messages.ToolReturn(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_id=call.tool_id,
+                    )
+                    return _MarkFinalResult(model_response), [tool_return]
 
             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
```
```diff
@@ -777,8 +855,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                     messages.append(self._unknown_tool(call.tool_name))
 
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-
-
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages
 
     async def _validate_result(
         self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
@@ -837,6 +916,12 @@ class Agent(Generic[AgentDeps, ResultData]):
                 if item is self:
                     self.name = name
                     return
+            if parent_frame.f_locals != parent_frame.f_globals:
+                # if we couldn't find the agent in locals and globals are a different dict, try globals
+                for name, item in parent_frame.f_globals.items():
+                    if item is self:
+                        self.name = name
+                        return
 
 
 @dataclass
```
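The fallback means name inference now also works when the agent was assigned to a module-level global rather than a local variable in the calling frame. Roughly:

```python
from pydantic_ai import Agent

weather_agent = Agent('test')  # module-level global

async def main():
    # inside main(), weather_agent is not in f_locals, only in f_globals;
    # with this change the first run can still infer the name 'weather_agent'
    await weather_agent.run('What is the weather?')
    assert weather_agent.name == 'weather_agent'
```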
```diff
@@ -844,6 +929,8 @@ class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.
 
     This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+
+    It also avoids problems in the case where the result type is itself `None`, but is set.
     """
 
     data: ResultData
```
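The added docstring line covers a subtle case: if the handlers returned `ResultData | None` directly, a legitimate final result of `None` would be indistinguishable from "no final result yet". The marker dataclass disambiguates, roughly:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

ResultData = TypeVar('ResultData')

@dataclass
class Marker(Generic[ResultData]):  # stand-in for _MarkFinalResult
    data: ResultData

# Marker(None) clearly means "a final result whose data is None",
# while a bare None means "keep running the loop"
final: Marker[None] | None = Marker(None)
assert final is not None and final.data is None
```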
pydantic_ai/messages.py (CHANGED)

```diff
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations
 
-import json
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union
```
```diff
@@ -74,6 +73,9 @@ class ToolReturn:
         return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
 
+ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
+
+
 @dataclass
 class RetryPrompt:
     """A message back to a model asking it to try again.
@@ -109,7 +111,8 @@ class RetryPrompt:
         if isinstance(self.content, str):
             description = self.content
         else:
-
+            json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
+            description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
 
```
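`RetryPrompt.model_response()` now serializes pydantic `ErrorDetails` via the lazy type adapter (dropping the noisy `ctx` key), which is what lets the module shed its `json` import. A rough sketch of the data being rendered, using plain pydantic:

```python
from pydantic import BaseModel, ValidationError

class Args(BaseModel):
    x: int

try:
    Args(x='not a number')
except ValidationError as exc:
    errors = exc.errors()  # list[pydantic_core.ErrorDetails]
    # RetryPrompt now renders roughly:
    #   "1 validation errors: [...]\n\nFix the errors and try again."
    print(len(errors), errors[0]['type'])  # -> 1 int_parsing
```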
pydantic_ai/models/__init__.py (CHANGED)

```diff
@@ -7,11 +7,11 @@ specific LLM being used.
 from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal,
+from typing import TYPE_CHECKING, Literal, Union
 
 import httpx
 
@@ -19,8 +19,8 @@ from ..exceptions import UserError
 from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
 
 if TYPE_CHECKING:
-    from .._utils import ObjectJsonSchema
     from ..result import Cost
+    from ..tools import ToolDefinition
 
 
 KnownModelName = Literal[
```
```diff
@@ -49,6 +49,23 @@
     'gemini-1.5-pro',
     'vertexai:gemini-1.5-flash',
     'vertexai:gemini-1.5-pro',
+    'ollama:codellama',
+    'ollama:gemma',
+    'ollama:gemma2',
+    'ollama:llama3',
+    'ollama:llama3.1',
+    'ollama:llama3.2',
+    'ollama:llama3.2-vision',
+    'ollama:llama3.3',
+    'ollama:mistral',
+    'ollama:mistral-nemo',
+    'ollama:mixtral',
+    'ollama:phi3',
+    'ollama:qwq',
+    'ollama:qwen',
+    'ollama:qwen2',
+    'ollama:qwen2.5',
+    'ollama:starcoder2',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
```
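Any of the new names can be passed straight to `Agent`. A minimal example (assumes a local Ollama server with the model already pulled):

```python
from pydantic_ai import Agent

# requires `ollama pull llama3.2` and a running Ollama server
agent = Agent('ollama:llama3.2')
result = agent.run_sync('What is the capital of France?')
print(result.data)
```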
```diff
@@ -63,11 +80,12 @@ class Model(ABC):
     @abstractmethod
     async def agent_model(
         self,
-
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools:
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        """Create an agent model.
+        """Create an agent model, this is called for each step of an agent run.
 
         This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
 
```
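Third-party `Model` implementations therefore have to move from the old positional `AbstractToolDefinition` sequences to keyword-only `list[ToolDefinition]` parameters. A skeletal subclass under the new signature (body elided; `MyModel` is hypothetical):

```python
from pydantic_ai import models
from pydantic_ai.tools import ToolDefinition

class MyModel(models.Model):
    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> models.AgentModel:
        raise NotImplementedError('build a per-step AgentModel here')

    def name(self) -> str:
        return 'my-model'
```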
```diff
@@ -87,7 +105,7 @@
 
 
 class AgentModel(ABC):
-    """Model configured for
+    """Model configured for each step of an Agent run."""
 
     @abstractmethod
     async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
@@ -238,7 +256,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
     elif model.startswith('openai:'):
         from .openai import OpenAIModel
 
-        return OpenAIModel(model[7:])
+        return OpenAIModel(model[7:])
     elif model.startswith('gemini'):
         from .gemini import GeminiModel
 
@@ -252,39 +270,14 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .vertexai import VertexAIModel
 
         return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('ollama:'):
+        from .ollama import OllamaModel
+
+        return OllamaModel(model[7:])
     else:
         raise UserError(f'Unknown model: {model}')
 
 
-class AbstractToolDefinition(Protocol):
-    """Abstract definition of a function/tool.
-
-    This is used for both function tools and result tools.
-    """
-
-    @property
-    def name(self) -> str:
-        """The name of the tool."""
-        ...
-
-    @property
-    def description(self) -> str:
-        """The description of the tool."""
-        ...
-
-    @property
-    def json_schema(self) -> ObjectJsonSchema:
-        """The JSON schema for the tool's arguments."""
-        ...
-
-    @property
-    def outer_typed_dict_key(self) -> str | None:
-        """The key in the outer [TypedDict] that wraps a result tool.
-
-        This will only be set for result tools which don't have an `object` JSON schema.
-        """
-
-
 @cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     """Cached HTTPX async client so multiple agents and calls can share the same client.
```