pydantic-ai-slim 0.2.18__tar.gz → 0.2.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/PKG-INFO +4 -4
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_agent_graph.py +68 -14
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_function_schema.py +14 -6
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_output.py +1 -1
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_system_prompt.py +1 -1
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_utils.py +28 -3
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/agent.py +13 -3
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/mcp.py +66 -5
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/messages.py +4 -5
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/anthropic.py +1 -1
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/gemini.py +1 -3
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/google.py +18 -6
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/__init__.py +23 -17
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/tools.py +1 -2
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/.gitignore +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/LICENSE +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/README.md +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_a2a.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/direct.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/ext/__init__.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/ext/langchain.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/__init__.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/_json_schema.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/amazon.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/anthropic.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/cohere.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/deepseek.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/google.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/grok.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/meta.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/mistral.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/openai.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/qwen.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/fireworks.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/google.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/grok.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/heroku.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/openrouter.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/together.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/usage.py +0 -0
- {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pyproject.toml +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.20
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
|
|
|
30
30
|
Requires-Dist: griffe>=1.3.2
|
|
31
31
|
Requires-Dist: httpx>=0.27
|
|
32
32
|
Requires-Dist: opentelemetry-api>=1.28.0
|
|
33
|
-
Requires-Dist: pydantic-graph==0.2.
|
|
33
|
+
Requires-Dist: pydantic-graph==0.2.20
|
|
34
34
|
Requires-Dist: pydantic>=2.10
|
|
35
35
|
Requires-Dist: typing-inspection>=0.4.0
|
|
36
36
|
Provides-Extra: a2a
|
|
37
|
-
Requires-Dist: fasta2a==0.2.
|
|
37
|
+
Requires-Dist: fasta2a==0.2.20; extra == 'a2a'
|
|
38
38
|
Provides-Extra: anthropic
|
|
39
39
|
Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
|
|
40
40
|
Provides-Extra: bedrock
|
|
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
|
|
|
48
48
|
Provides-Extra: duckduckgo
|
|
49
49
|
Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
|
|
50
50
|
Provides-Extra: evals
|
|
51
|
-
Requires-Dist: pydantic-evals==0.2.
|
|
51
|
+
Requires-Dist: pydantic-evals==0.2.20; extra == 'evals'
|
|
52
52
|
Provides-Extra: google
|
|
53
53
|
Requires-Dist: google-genai>=1.15.0; extra == 'google'
|
|
54
54
|
Provides-Extra: groq
|
|
@@ -12,18 +12,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
|
|
|
12
12
|
from opentelemetry.trace import Tracer
|
|
13
13
|
from typing_extensions import TypeGuard, TypeVar, assert_never
|
|
14
14
|
|
|
15
|
+
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
|
|
16
|
+
from pydantic_ai._utils import is_async_callable, run_in_executor
|
|
15
17
|
from pydantic_graph import BaseNode, Graph, GraphRunContext
|
|
16
18
|
from pydantic_graph.nodes import End, NodeRunEndT
|
|
17
19
|
|
|
18
|
-
from . import
|
|
19
|
-
_output,
|
|
20
|
-
_system_prompt,
|
|
21
|
-
exceptions,
|
|
22
|
-
messages as _messages,
|
|
23
|
-
models,
|
|
24
|
-
result,
|
|
25
|
-
usage as _usage,
|
|
26
|
-
)
|
|
20
|
+
from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
|
|
27
21
|
from .result import OutputDataT
|
|
28
22
|
from .settings import ModelSettings, merge_model_settings
|
|
29
23
|
from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
|
|
@@ -39,6 +33,7 @@ __all__ = (
|
|
|
39
33
|
'CallToolsNode',
|
|
40
34
|
'build_run_context',
|
|
41
35
|
'capture_run_messages',
|
|
36
|
+
'HistoryProcessor',
|
|
42
37
|
)
|
|
43
38
|
|
|
44
39
|
|
|
@@ -54,6 +49,23 @@ EndStrategy = Literal['early', 'exhaustive']
|
|
|
54
49
|
DepsT = TypeVar('DepsT')
|
|
55
50
|
OutputT = TypeVar('OutputT')
|
|
56
51
|
|
|
52
|
+
_HistoryProcessorSync = Callable[[list[_messages.ModelMessage]], list[_messages.ModelMessage]]
|
|
53
|
+
_HistoryProcessorAsync = Callable[[list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]]
|
|
54
|
+
_HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.ModelMessage]], list[_messages.ModelMessage]]
|
|
55
|
+
_HistoryProcessorAsyncWithCtx = Callable[
|
|
56
|
+
[RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]
|
|
57
|
+
]
|
|
58
|
+
HistoryProcessor = Union[
|
|
59
|
+
_HistoryProcessorSync,
|
|
60
|
+
_HistoryProcessorAsync,
|
|
61
|
+
_HistoryProcessorSyncWithCtx[DepsT],
|
|
62
|
+
_HistoryProcessorAsyncWithCtx[DepsT],
|
|
63
|
+
]
|
|
64
|
+
"""A function that processes a list of model messages and returns a list of model messages.
|
|
65
|
+
|
|
66
|
+
Can optionally accept a `RunContext` as a parameter.
|
|
67
|
+
"""
|
|
68
|
+
|
|
57
69
|
|
|
58
70
|
@dataclasses.dataclass
|
|
59
71
|
class GraphAgentState:
|
|
@@ -93,6 +105,8 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
|
|
|
93
105
|
output_schema: _output.OutputSchema[OutputDataT] | None
|
|
94
106
|
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
|
|
95
107
|
|
|
108
|
+
history_processors: Sequence[HistoryProcessor[DepsT]]
|
|
109
|
+
|
|
96
110
|
function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
|
|
97
111
|
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
|
98
112
|
default_retries: int
|
|
@@ -327,8 +341,11 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
327
341
|
|
|
328
342
|
model_settings, model_request_parameters = await self._prepare_request(ctx)
|
|
329
343
|
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
|
|
344
|
+
message_history = await _process_message_history(
|
|
345
|
+
ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
|
|
346
|
+
)
|
|
330
347
|
async with ctx.deps.model.request_stream(
|
|
331
|
-
|
|
348
|
+
message_history, model_settings, model_request_parameters
|
|
332
349
|
) as streamed_response:
|
|
333
350
|
self._did_stream = True
|
|
334
351
|
ctx.state.usage.requests += 1
|
|
@@ -350,9 +367,10 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
350
367
|
|
|
351
368
|
model_settings, model_request_parameters = await self._prepare_request(ctx)
|
|
352
369
|
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
|
|
353
|
-
|
|
354
|
-
ctx.state.message_history,
|
|
370
|
+
message_history = await _process_message_history(
|
|
371
|
+
ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
|
|
355
372
|
)
|
|
373
|
+
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
|
|
356
374
|
ctx.state.usage.incr(_usage.Usage())
|
|
357
375
|
|
|
358
376
|
return self._finish_handling(ctx, model_response)
|
|
@@ -647,6 +665,7 @@ async def process_function_tools( # noqa C901
|
|
|
647
665
|
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
|
|
648
666
|
# validation, we don't add another part here
|
|
649
667
|
if output_tool_name is not None:
|
|
668
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
650
669
|
if found_used_output_tool:
|
|
651
670
|
content = 'Output tool not used - a final result was already processed.'
|
|
652
671
|
else:
|
|
@@ -657,9 +676,14 @@ async def process_function_tools( # noqa C901
|
|
|
657
676
|
content=content,
|
|
658
677
|
tool_call_id=call.tool_call_id,
|
|
659
678
|
)
|
|
679
|
+
yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
|
|
660
680
|
output_parts.append(part)
|
|
661
681
|
else:
|
|
662
|
-
|
|
682
|
+
yield _messages.FunctionToolCallEvent(call)
|
|
683
|
+
|
|
684
|
+
part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
|
|
685
|
+
yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
|
|
686
|
+
output_parts.append(part)
|
|
663
687
|
|
|
664
688
|
if not calls_to_run:
|
|
665
689
|
return
|
|
@@ -755,7 +779,12 @@ async def _tool_from_mcp_server(
|
|
|
755
779
|
# some weird edge case occurs.
|
|
756
780
|
if not server.is_running: # pragma: no cover
|
|
757
781
|
raise exceptions.UserError(f'MCP server is not running: {server}')
|
|
758
|
-
|
|
782
|
+
|
|
783
|
+
if server.process_tool_call is not None:
|
|
784
|
+
result = await server.process_tool_call(ctx, server.call_tool, tool_name, args)
|
|
785
|
+
else:
|
|
786
|
+
result = await server.call_tool(tool_name, args)
|
|
787
|
+
|
|
759
788
|
return result
|
|
760
789
|
|
|
761
790
|
for server in ctx.deps.mcp_servers:
|
|
@@ -865,3 +894,28 @@ def build_agent_graph(
|
|
|
865
894
|
auto_instrument=False,
|
|
866
895
|
)
|
|
867
896
|
return graph
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
async def _process_message_history(
|
|
900
|
+
messages: list[_messages.ModelMessage],
|
|
901
|
+
processors: Sequence[HistoryProcessor[DepsT]],
|
|
902
|
+
run_context: RunContext[DepsT],
|
|
903
|
+
) -> list[_messages.ModelMessage]:
|
|
904
|
+
"""Process message history through a sequence of processors."""
|
|
905
|
+
for processor in processors:
|
|
906
|
+
takes_ctx = is_takes_ctx(processor)
|
|
907
|
+
|
|
908
|
+
if is_async_callable(processor):
|
|
909
|
+
if takes_ctx:
|
|
910
|
+
messages = await processor(run_context, messages)
|
|
911
|
+
else:
|
|
912
|
+
async_processor = cast(_HistoryProcessorAsync, processor)
|
|
913
|
+
messages = await async_processor(messages)
|
|
914
|
+
else:
|
|
915
|
+
if takes_ctx:
|
|
916
|
+
sync_processor_with_ctx = cast(_HistoryProcessorSyncWithCtx[DepsT], processor)
|
|
917
|
+
messages = await run_in_executor(sync_processor_with_ctx, run_context, messages)
|
|
918
|
+
else:
|
|
919
|
+
sync_processor = cast(_HistoryProcessorSync, processor)
|
|
920
|
+
messages = await run_in_executor(sync_processor, messages)
|
|
921
|
+
return messages
|
|
@@ -5,11 +5,10 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations as _annotations
|
|
7
7
|
|
|
8
|
-
import inspect
|
|
9
8
|
from collections.abc import Awaitable
|
|
10
9
|
from dataclasses import dataclass, field
|
|
11
10
|
from inspect import Parameter, signature
|
|
12
|
-
from typing import TYPE_CHECKING, Any, Callable, cast
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable, Union, cast
|
|
13
12
|
|
|
14
13
|
from pydantic import ConfigDict
|
|
15
14
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
@@ -18,12 +17,12 @@ from pydantic.fields import FieldInfo
|
|
|
18
17
|
from pydantic.json_schema import GenerateJsonSchema
|
|
19
18
|
from pydantic.plugin._schema_validator import create_schema_validator
|
|
20
19
|
from pydantic_core import SchemaValidator, core_schema
|
|
21
|
-
from typing_extensions import get_origin
|
|
20
|
+
from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar, get_origin
|
|
22
21
|
|
|
23
22
|
from pydantic_ai.tools import RunContext
|
|
24
23
|
|
|
25
24
|
from ._griffe import doc_descriptions
|
|
26
|
-
from ._utils import check_object_json_schema, is_model_like, run_in_executor
|
|
25
|
+
from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
|
|
27
26
|
|
|
28
27
|
if TYPE_CHECKING:
|
|
29
28
|
from .tools import DocstringFormat, ObjectJsonSchema
|
|
@@ -214,12 +213,21 @@ def function_schema( # noqa: C901
|
|
|
214
213
|
positional_fields=positional_fields,
|
|
215
214
|
var_positional_field=var_positional_field,
|
|
216
215
|
takes_ctx=takes_ctx,
|
|
217
|
-
is_async=
|
|
216
|
+
is_async=is_async_callable(function),
|
|
218
217
|
function=function,
|
|
219
218
|
)
|
|
220
219
|
|
|
221
220
|
|
|
222
|
-
|
|
221
|
+
P = ParamSpec('P')
|
|
222
|
+
R = TypeVar('R')
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
WithCtx = Callable[Concatenate[RunContext[Any], P], R]
|
|
226
|
+
WithoutCtx = Callable[P, R]
|
|
227
|
+
TargetFunc = Union[WithCtx[P, R], WithoutCtx[P, R]]
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
|
|
223
231
|
"""Check if a function takes a `RunContext` first argument.
|
|
224
232
|
|
|
225
233
|
Args:
|
|
@@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
60
60
|
|
|
61
61
|
def __post_init__(self):
|
|
62
62
|
self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
|
|
63
|
-
self._is_async =
|
|
63
|
+
self._is_async = _utils.is_async_callable(self.function)
|
|
64
64
|
|
|
65
65
|
async def validate(
|
|
66
66
|
self,
|
|
@@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]):
|
|
|
18
18
|
|
|
19
19
|
def __post_init__(self):
|
|
20
20
|
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
|
|
21
|
-
self._is_async =
|
|
21
|
+
self._is_async = _utils.is_async_callable(self.function)
|
|
22
22
|
|
|
23
23
|
async def run(self, run_context: RunContext[AgentDepsT]) -> str:
|
|
24
24
|
if self._takes_ctx:
|
|
@@ -1,20 +1,22 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
import inspect
|
|
4
6
|
import time
|
|
5
7
|
import uuid
|
|
6
|
-
from collections.abc import AsyncIterable, AsyncIterator, Iterator
|
|
8
|
+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
|
|
7
9
|
from contextlib import asynccontextmanager, suppress
|
|
8
10
|
from dataclasses import dataclass, fields, is_dataclass
|
|
9
11
|
from datetime import datetime, timezone
|
|
10
12
|
from functools import partial
|
|
11
13
|
from types import GenericAlias
|
|
12
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
|
|
14
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload
|
|
13
15
|
|
|
14
16
|
from anyio.to_thread import run_sync
|
|
15
17
|
from pydantic import BaseModel, TypeAdapter
|
|
16
18
|
from pydantic.json_schema import JsonSchemaValue
|
|
17
|
-
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
|
|
19
|
+
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict
|
|
18
20
|
|
|
19
21
|
from pydantic_graph._utils import AbstractSpan
|
|
20
22
|
|
|
@@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str:
|
|
|
302
304
|
|
|
303
305
|
def number_to_datetime(x: int | float) -> datetime:
|
|
304
306
|
return TypeAdapter(datetime).validate_python(x)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
AwaitableCallable = Callable[..., Awaitable[T]]
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
@overload
|
|
313
|
+
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
@overload
|
|
317
|
+
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def is_async_callable(obj: Any) -> Any:
|
|
321
|
+
"""Correctly check if a callable is async.
|
|
322
|
+
|
|
323
|
+
This function was copied from Starlette:
|
|
324
|
+
https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40
|
|
325
|
+
"""
|
|
326
|
+
while isinstance(obj, functools.partial):
|
|
327
|
+
obj = obj.func
|
|
328
|
+
|
|
329
|
+
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
|
|
@@ -28,6 +28,7 @@ from . import (
|
|
|
28
28
|
result,
|
|
29
29
|
usage as _usage,
|
|
30
30
|
)
|
|
31
|
+
from ._agent_graph import HistoryProcessor
|
|
31
32
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
|
|
32
33
|
from .result import FinalResult, OutputDataT, StreamedRunResult
|
|
33
34
|
from .settings import ModelSettings, merge_model_settings
|
|
@@ -179,6 +180,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
179
180
|
defer_model_check: bool = False,
|
|
180
181
|
end_strategy: EndStrategy = 'early',
|
|
181
182
|
instrument: InstrumentationSettings | bool | None = None,
|
|
183
|
+
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
|
|
182
184
|
) -> None: ...
|
|
183
185
|
|
|
184
186
|
@overload
|
|
@@ -208,6 +210,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
208
210
|
defer_model_check: bool = False,
|
|
209
211
|
end_strategy: EndStrategy = 'early',
|
|
210
212
|
instrument: InstrumentationSettings | bool | None = None,
|
|
213
|
+
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
|
|
211
214
|
) -> None: ...
|
|
212
215
|
|
|
213
216
|
def __init__(
|
|
@@ -232,6 +235,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
232
235
|
defer_model_check: bool = False,
|
|
233
236
|
end_strategy: EndStrategy = 'early',
|
|
234
237
|
instrument: InstrumentationSettings | bool | None = None,
|
|
238
|
+
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
|
|
235
239
|
**_deprecated_kwargs: Any,
|
|
236
240
|
):
|
|
237
241
|
"""Create an agent.
|
|
@@ -275,6 +279,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
275
279
|
[`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
|
|
276
280
|
will be used, which defaults to False.
|
|
277
281
|
See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
|
|
282
|
+
history_processors: Optional list of callables to process the message history before sending it to the model.
|
|
283
|
+
Each processor takes a list of messages and returns a modified list of messages.
|
|
284
|
+
Processors can be sync or async and are applied in sequence.
|
|
278
285
|
"""
|
|
279
286
|
if model is None or defer_model_check:
|
|
280
287
|
self.model = model
|
|
@@ -343,6 +350,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
343
350
|
self._max_result_retries = output_retries if output_retries is not None else retries
|
|
344
351
|
self._mcp_servers = mcp_servers
|
|
345
352
|
self._prepare_tools = prepare_tools
|
|
353
|
+
self.history_processors = history_processors or []
|
|
346
354
|
for tool in tools:
|
|
347
355
|
if isinstance(tool, Tool):
|
|
348
356
|
self._register_tool(tool)
|
|
@@ -669,10 +677,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
669
677
|
if self._instructions is None and not self._instructions_functions:
|
|
670
678
|
return None
|
|
671
679
|
|
|
672
|
-
instructions = self._instructions
|
|
680
|
+
instructions = [self._instructions] if self._instructions else []
|
|
673
681
|
for instructions_runner in self._instructions_functions:
|
|
674
|
-
instructions
|
|
675
|
-
|
|
682
|
+
instructions.append(await instructions_runner.run(run_context))
|
|
683
|
+
concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
|
|
684
|
+
return concatenated_instructions.strip() if concatenated_instructions else None
|
|
676
685
|
|
|
677
686
|
# Copy the function tools so that retry state is agent-run-specific
|
|
678
687
|
# Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
|
|
@@ -689,6 +698,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
689
698
|
end_strategy=self.end_strategy,
|
|
690
699
|
output_schema=output_schema,
|
|
691
700
|
output_validators=output_validators,
|
|
701
|
+
history_processors=self.history_processors,
|
|
692
702
|
function_tools=run_function_tools,
|
|
693
703
|
mcp_servers=self._mcp_servers,
|
|
694
704
|
default_retries=self._default_retries,
|
|
@@ -4,7 +4,7 @@ import base64
|
|
|
4
4
|
import functools
|
|
5
5
|
import json
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
|
-
from collections.abc import AsyncIterator, Sequence
|
|
7
|
+
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
8
8
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
9
9
|
from dataclasses import dataclass
|
|
10
10
|
from pathlib import Path
|
|
@@ -15,14 +15,20 @@ import anyio
|
|
|
15
15
|
import httpx
|
|
16
16
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
17
17
|
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
|
|
18
|
+
from mcp.shared.exceptions import McpError
|
|
18
19
|
from mcp.shared.message import SessionMessage
|
|
19
20
|
from mcp.types import (
|
|
20
21
|
AudioContent,
|
|
21
22
|
BlobResourceContents,
|
|
23
|
+
CallToolRequest,
|
|
24
|
+
CallToolRequestParams,
|
|
25
|
+
CallToolResult,
|
|
26
|
+
ClientRequest,
|
|
22
27
|
Content,
|
|
23
28
|
EmbeddedResource,
|
|
24
29
|
ImageContent,
|
|
25
30
|
LoggingLevel,
|
|
31
|
+
RequestParams,
|
|
26
32
|
TextContent,
|
|
27
33
|
TextResourceContents,
|
|
28
34
|
)
|
|
@@ -30,7 +36,7 @@ from typing_extensions import Self, assert_never, deprecated
|
|
|
30
36
|
|
|
31
37
|
from pydantic_ai.exceptions import ModelRetry
|
|
32
38
|
from pydantic_ai.messages import BinaryContent
|
|
33
|
-
from pydantic_ai.tools import ToolDefinition
|
|
39
|
+
from pydantic_ai.tools import RunContext, ToolDefinition
|
|
34
40
|
|
|
35
41
|
try:
|
|
36
42
|
from mcp.client.session import ClientSession
|
|
@@ -60,6 +66,9 @@ class MCPServer(ABC):
|
|
|
60
66
|
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
61
67
|
"""
|
|
62
68
|
|
|
69
|
+
process_tool_call: ProcessToolCallback | None = None
|
|
70
|
+
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
71
|
+
|
|
63
72
|
_client: ClientSession
|
|
64
73
|
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
|
65
74
|
_write_stream: MemoryObjectSendStream[SessionMessage]
|
|
@@ -113,13 +122,17 @@ class MCPServer(ABC):
|
|
|
113
122
|
]
|
|
114
123
|
|
|
115
124
|
async def call_tool(
|
|
116
|
-
self,
|
|
117
|
-
|
|
125
|
+
self,
|
|
126
|
+
tool_name: str,
|
|
127
|
+
arguments: dict[str, Any],
|
|
128
|
+
metadata: dict[str, Any] | None = None,
|
|
129
|
+
) -> ToolResult:
|
|
118
130
|
"""Call a tool on the server.
|
|
119
131
|
|
|
120
132
|
Args:
|
|
121
133
|
tool_name: The name of the tool to call.
|
|
122
134
|
arguments: The arguments to pass to the tool.
|
|
135
|
+
metadata: Request-level metadata (optional)
|
|
123
136
|
|
|
124
137
|
Returns:
|
|
125
138
|
The result of the tool call.
|
|
@@ -127,7 +140,23 @@ class MCPServer(ABC):
|
|
|
127
140
|
Raises:
|
|
128
141
|
ModelRetry: If the tool call fails.
|
|
129
142
|
"""
|
|
130
|
-
|
|
143
|
+
try:
|
|
144
|
+
# meta param is not provided by session yet, so build and can send_request directly.
|
|
145
|
+
result = await self._client.send_request(
|
|
146
|
+
ClientRequest(
|
|
147
|
+
CallToolRequest(
|
|
148
|
+
method='tools/call',
|
|
149
|
+
params=CallToolRequestParams(
|
|
150
|
+
name=self.get_unprefixed_tool_name(tool_name),
|
|
151
|
+
arguments=arguments,
|
|
152
|
+
_meta=RequestParams.Meta(**metadata) if metadata else None,
|
|
153
|
+
),
|
|
154
|
+
)
|
|
155
|
+
),
|
|
156
|
+
CallToolResult,
|
|
157
|
+
)
|
|
158
|
+
except McpError as e:
|
|
159
|
+
raise ModelRetry(e.error.message)
|
|
131
160
|
|
|
132
161
|
content = [self._map_tool_result_part(part) for part in result.content]
|
|
133
162
|
|
|
@@ -265,6 +294,9 @@ class MCPServerStdio(MCPServer):
|
|
|
265
294
|
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
266
295
|
"""
|
|
267
296
|
|
|
297
|
+
process_tool_call: ProcessToolCallback | None = None
|
|
298
|
+
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
299
|
+
|
|
268
300
|
timeout: float = 5
|
|
269
301
|
""" The timeout in seconds to wait for the client to initialize."""
|
|
270
302
|
|
|
@@ -359,6 +391,9 @@ class _MCPServerHTTP(MCPServer):
|
|
|
359
391
|
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
360
392
|
"""
|
|
361
393
|
|
|
394
|
+
process_tool_call: ProcessToolCallback | None = None
|
|
395
|
+
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
396
|
+
|
|
362
397
|
@property
|
|
363
398
|
@abstractmethod
|
|
364
399
|
def _transport_client(
|
|
@@ -517,3 +552,29 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
|
|
|
517
552
|
@property
|
|
518
553
|
def _transport_client(self):
|
|
519
554
|
return streamablehttp_client # pragma: no cover
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
ToolResult = (
|
|
558
|
+
str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
|
|
559
|
+
)
|
|
560
|
+
"""The result type of a tool call."""
|
|
561
|
+
|
|
562
|
+
CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]
|
|
563
|
+
"""A function type that represents a tool call."""
|
|
564
|
+
|
|
565
|
+
ProcessToolCallback = Callable[
|
|
566
|
+
[
|
|
567
|
+
RunContext[Any],
|
|
568
|
+
CallToolFunc,
|
|
569
|
+
str,
|
|
570
|
+
dict[str, Any],
|
|
571
|
+
],
|
|
572
|
+
Awaitable[ToolResult],
|
|
573
|
+
]
|
|
574
|
+
"""A process tool callback.
|
|
575
|
+
|
|
576
|
+
It accepts a run context, the original tool call function, a tool name, and arguments.
|
|
577
|
+
|
|
578
|
+
Allows wrapping an MCP server tool call to customize it, including adding extra request
|
|
579
|
+
metadata.
|
|
580
|
+
"""
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import uuid
|
|
5
4
|
from abc import ABC, abstractmethod
|
|
6
5
|
from collections.abc import Sequence
|
|
7
6
|
from dataclasses import dataclass, field, replace
|
|
@@ -888,13 +887,13 @@ class FunctionToolCallEvent:
|
|
|
888
887
|
|
|
889
888
|
part: ToolCallPart
|
|
890
889
|
"""The (function) tool call to make."""
|
|
891
|
-
call_id: str = field(init=False)
|
|
892
|
-
"""An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
|
|
893
890
|
event_kind: Literal['function_tool_call'] = 'function_tool_call'
|
|
894
891
|
"""Event type identifier, used as a discriminator."""
|
|
895
892
|
|
|
896
|
-
|
|
897
|
-
|
|
893
|
+
@property
|
|
894
|
+
def call_id(self) -> str:
|
|
895
|
+
"""An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
|
|
896
|
+
return self.part.tool_call_id
|
|
898
897
|
|
|
899
898
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
900
899
|
|
|
@@ -220,7 +220,7 @@ class AnthropicModel(Model):
|
|
|
220
220
|
extra_headers = model_settings.get('extra_headers', {})
|
|
221
221
|
extra_headers.setdefault('User-Agent', get_user_agent())
|
|
222
222
|
return await self.client.beta.messages.create(
|
|
223
|
-
max_tokens=model_settings.get('max_tokens',
|
|
223
|
+
max_tokens=model_settings.get('max_tokens', 4096),
|
|
224
224
|
system=system_prompt or NOT_GIVEN,
|
|
225
225
|
messages=anthropic_messages,
|
|
226
226
|
model=self._model_name,
|
|
@@ -723,9 +723,7 @@ class _GeminiFunction(TypedDict):
|
|
|
723
723
|
|
|
724
724
|
def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
|
|
725
725
|
json_schema = tool.parameters_json_schema
|
|
726
|
-
f = _GeminiFunction(name=tool.name, description=tool.description)
|
|
727
|
-
if json_schema.get('properties'):
|
|
728
|
-
f['parameters'] = json_schema
|
|
726
|
+
f = _GeminiFunction(name=tool.name, description=tool.description, parameters=json_schema)
|
|
729
727
|
return f
|
|
730
728
|
|
|
731
729
|
|
|
@@ -10,9 +10,8 @@ from uuid import uuid4
|
|
|
10
10
|
|
|
11
11
|
from typing_extensions import assert_never
|
|
12
12
|
|
|
13
|
-
from pydantic_ai.providers import Provider
|
|
14
|
-
|
|
15
13
|
from .. import UnexpectedModelBehavior, _utils, usage
|
|
14
|
+
from ..exceptions import UserError
|
|
16
15
|
from ..messages import (
|
|
17
16
|
BinaryContent,
|
|
18
17
|
FileUrl,
|
|
@@ -30,6 +29,7 @@ from ..messages import (
|
|
|
30
29
|
VideoUrl,
|
|
31
30
|
)
|
|
32
31
|
from ..profiles import ModelProfileSpec
|
|
32
|
+
from ..providers import Provider
|
|
33
33
|
from ..settings import ModelSettings
|
|
34
34
|
from ..tools import ToolDefinition
|
|
35
35
|
from . import (
|
|
@@ -52,6 +52,7 @@ try:
|
|
|
52
52
|
FunctionDeclarationDict,
|
|
53
53
|
GenerateContentConfigDict,
|
|
54
54
|
GenerateContentResponse,
|
|
55
|
+
HttpOptionsDict,
|
|
55
56
|
Part,
|
|
56
57
|
PartDict,
|
|
57
58
|
SafetySettingDict,
|
|
@@ -252,8 +253,17 @@ class GoogleModel(Model):
|
|
|
252
253
|
tool_config = self._get_tool_config(model_request_parameters, tools)
|
|
253
254
|
system_instruction, contents = await self._map_messages(messages)
|
|
254
255
|
|
|
256
|
+
http_options: HttpOptionsDict = {
|
|
257
|
+
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
258
|
+
}
|
|
259
|
+
if timeout := model_settings.get('timeout'):
|
|
260
|
+
if isinstance(timeout, (int, float)):
|
|
261
|
+
http_options['timeout'] = int(1000 * timeout)
|
|
262
|
+
else:
|
|
263
|
+
raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
|
|
264
|
+
|
|
255
265
|
config = GenerateContentConfigDict(
|
|
256
|
-
http_options=
|
|
266
|
+
http_options=http_options,
|
|
257
267
|
system_instruction=system_instruction,
|
|
258
268
|
temperature=model_settings.get('temperature'),
|
|
259
269
|
top_p=model_settings.get('top_p'),
|
|
@@ -469,9 +479,11 @@ def _process_response_from_parts(
|
|
|
469
479
|
|
|
470
480
|
def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
|
|
471
481
|
json_schema = tool.parameters_json_schema
|
|
472
|
-
f = FunctionDeclarationDict(
|
|
473
|
-
|
|
474
|
-
|
|
482
|
+
f = FunctionDeclarationDict(
|
|
483
|
+
name=tool.name,
|
|
484
|
+
description=tool.description,
|
|
485
|
+
parameters=json_schema, # type: ignore
|
|
486
|
+
)
|
|
475
487
|
return f
|
|
476
488
|
|
|
477
489
|
|
|
@@ -48,68 +48,74 @@ class Provider(ABC, Generic[InterfaceClient]):
|
|
|
48
48
|
return None # pragma: no cover
|
|
49
49
|
|
|
50
50
|
|
|
51
|
-
def
|
|
52
|
-
"""
|
|
51
|
+
def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
|
|
52
|
+
"""Infers the provider class from the provider name."""
|
|
53
53
|
if provider == 'openai':
|
|
54
54
|
from .openai import OpenAIProvider
|
|
55
55
|
|
|
56
|
-
return OpenAIProvider
|
|
56
|
+
return OpenAIProvider
|
|
57
57
|
elif provider == 'deepseek':
|
|
58
58
|
from .deepseek import DeepSeekProvider
|
|
59
59
|
|
|
60
|
-
return DeepSeekProvider
|
|
60
|
+
return DeepSeekProvider
|
|
61
61
|
elif provider == 'openrouter':
|
|
62
62
|
from .openrouter import OpenRouterProvider
|
|
63
63
|
|
|
64
|
-
return OpenRouterProvider
|
|
64
|
+
return OpenRouterProvider
|
|
65
65
|
elif provider == 'azure':
|
|
66
66
|
from .azure import AzureProvider
|
|
67
67
|
|
|
68
|
-
return AzureProvider
|
|
68
|
+
return AzureProvider
|
|
69
69
|
elif provider == 'google-vertex':
|
|
70
70
|
from .google_vertex import GoogleVertexProvider
|
|
71
71
|
|
|
72
|
-
return GoogleVertexProvider
|
|
72
|
+
return GoogleVertexProvider
|
|
73
73
|
elif provider == 'google-gla':
|
|
74
74
|
from .google_gla import GoogleGLAProvider
|
|
75
75
|
|
|
76
|
-
return GoogleGLAProvider
|
|
76
|
+
return GoogleGLAProvider
|
|
77
77
|
# NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
|
|
78
78
|
elif provider == 'bedrock':
|
|
79
79
|
from .bedrock import BedrockProvider
|
|
80
80
|
|
|
81
|
-
return BedrockProvider
|
|
81
|
+
return BedrockProvider
|
|
82
82
|
elif provider == 'groq':
|
|
83
83
|
from .groq import GroqProvider
|
|
84
84
|
|
|
85
|
-
return GroqProvider
|
|
85
|
+
return GroqProvider
|
|
86
86
|
elif provider == 'anthropic':
|
|
87
87
|
from .anthropic import AnthropicProvider
|
|
88
88
|
|
|
89
|
-
return AnthropicProvider
|
|
89
|
+
return AnthropicProvider
|
|
90
90
|
elif provider == 'mistral':
|
|
91
91
|
from .mistral import MistralProvider
|
|
92
92
|
|
|
93
|
-
return MistralProvider
|
|
93
|
+
return MistralProvider
|
|
94
94
|
elif provider == 'cohere':
|
|
95
95
|
from .cohere import CohereProvider
|
|
96
96
|
|
|
97
|
-
return CohereProvider
|
|
97
|
+
return CohereProvider
|
|
98
98
|
elif provider == 'grok':
|
|
99
99
|
from .grok import GrokProvider
|
|
100
100
|
|
|
101
|
-
return GrokProvider
|
|
101
|
+
return GrokProvider
|
|
102
102
|
elif provider == 'fireworks':
|
|
103
103
|
from .fireworks import FireworksProvider
|
|
104
104
|
|
|
105
|
-
return FireworksProvider
|
|
105
|
+
return FireworksProvider
|
|
106
106
|
elif provider == 'together':
|
|
107
107
|
from .together import TogetherProvider
|
|
108
108
|
|
|
109
|
-
return TogetherProvider
|
|
109
|
+
return TogetherProvider
|
|
110
110
|
elif provider == 'heroku':
|
|
111
111
|
from .heroku import HerokuProvider
|
|
112
112
|
|
|
113
|
-
return HerokuProvider
|
|
113
|
+
return HerokuProvider
|
|
114
114
|
else: # pragma: no cover
|
|
115
115
|
raise ValueError(f'Unknown provider: {provider}')
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def infer_provider(provider: str) -> Provider[Any]:
|
|
119
|
+
"""Infer the provider from the provider name."""
|
|
120
|
+
provider_class = infer_provider_class(provider)
|
|
121
|
+
return provider_class()
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
import dataclasses
|
|
5
4
|
import json
|
|
6
5
|
from collections.abc import Awaitable, Sequence
|
|
@@ -337,7 +336,7 @@ class Tool(Generic[AgentDepsT]):
|
|
|
337
336
|
validator=SchemaValidator(schema=core_schema.any_schema()),
|
|
338
337
|
json_schema=json_schema,
|
|
339
338
|
takes_ctx=False,
|
|
340
|
-
is_async=
|
|
339
|
+
is_async=_utils.is_async_callable(function),
|
|
341
340
|
)
|
|
342
341
|
|
|
343
342
|
return cls(
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|