pydantic-ai-slim 0.0.32__tar.gz → 0.0.34__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/PKG-INFO +7 -2
- pydantic_ai_slim-0.0.34/pydantic_ai/_cli.py +225 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_pydantic.py +4 -4
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_result.py +7 -18
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/agent.py +29 -9
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/messages.py +11 -2
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/__init__.py +36 -36
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/gemini.py +51 -14
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/instrumented.py +43 -9
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/openai.py +56 -15
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/vertexai.py +9 -1
- pydantic_ai_slim-0.0.34/pydantic_ai/providers/__init__.py +64 -0
- pydantic_ai_slim-0.0.34/pydantic_ai/providers/deepseek.py +68 -0
- pydantic_ai_slim-0.0.34/pydantic_ai/providers/google_gla.py +44 -0
- pydantic_ai_slim-0.0.34/pydantic_ai/providers/google_vertex.py +200 -0
- pydantic_ai_slim-0.0.34/pydantic_ai/providers/openai.py +72 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pyproject.toml +8 -2
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/README.md +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.32
+Version: 0.0.34
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,10 +29,15 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.32
+Requires-Dist: pydantic-graph==0.0.34
 Requires-Dist: pydantic>=2.10
+Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
+Provides-Extra: cli
+Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
+Requires-Dist: prompt-toolkit>=3; extra == 'cli'
+Requires-Dist: rich>=13; extra == 'cli'
 Provides-Extra: cohere
 Requires-Dist: cohere>=5.13.11; extra == 'cohere'
 Provides-Extra: duckduckgo
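The new `cli` extra covers the dependencies of the `pai` command introduced in `pydantic_ai/_cli.py` below; per the import-error message in that file it is installed with `pip install 'pydantic-ai-slim[cli]'`.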
@@ -0,0 +1,225 @@
+from __future__ import annotations as _annotations
+
+import argparse
+import asyncio
+import sys
+from collections.abc import Sequence
+from datetime import datetime, timezone
+from importlib.metadata import version
+from pathlib import Path
+from typing import cast
+
+from typing_inspection.introspection import get_literal_values
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import KnownModelName
+from pydantic_graph.nodes import End
+
+try:
+    import argcomplete
+    from prompt_toolkit import PromptSession
+    from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, Suggestion
+    from prompt_toolkit.buffer import Buffer
+    from prompt_toolkit.document import Document
+    from prompt_toolkit.history import FileHistory
+    from rich.console import Console, ConsoleOptions, RenderResult
+    from rich.live import Live
+    from rich.markdown import CodeBlock, Markdown
+    from rich.status import Status
+    from rich.syntax import Syntax
+    from rich.text import Text
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the PydanticAI CLI, '
+        "you can use the `cli` optional group — `pip install 'pydantic-ai-slim[cli]'`"
+    ) from _import_error
+
+from pydantic_ai.agent import Agent
+from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
+
+__version__ = version('pydantic-ai')
+
+
+class SimpleCodeBlock(CodeBlock):
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:  # pragma: no cover
+        code = str(self.text).rstrip()
+        yield Text(self.lexer_name, style='dim')
+        yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True)
+        yield Text(f'/{self.lexer_name}', style='dim')
+
+
+Markdown.elements['fence'] = SimpleCodeBlock
+
+
+def cli(args_list: Sequence[str] | None = None) -> int:  # noqa: C901  # pragma: no cover
+    parser = argparse.ArgumentParser(
+        prog='pai',
+        description=f"""\
+PydanticAI CLI v{__version__}\n\n
+
+Special prompt:
+* `/exit` - exit the interactive mode
+* `/markdown` - show the last markdown output of the last question
+* `/multiline` - toggle multiline mode
+""",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
+    parser.add_argument(
+        '--model',
+        nargs='?',
+        help='Model to use, it should be "<provider>:<model>" e.g. "openai:gpt-4o". If omitted it will default to "openai:gpt-4o"',
+        default='openai:gpt-4o',
+    ).completer = argcomplete.ChoicesCompleter(list(get_literal_values(KnownModelName)))  # type: ignore[reportPrivateUsage]
+    parser.add_argument('--no-stream', action='store_true', help='Whether to stream responses from OpenAI')
+    parser.add_argument('--version', action='store_true', help='Show version and exit')
+
+    argcomplete.autocomplete(parser)
+    args = parser.parse_args(args_list)
+
+    console = Console()
+    console.print(f'pai - PydanticAI CLI v{__version__}', style='green bold', highlight=False)
+    if args.version:
+        return 0
+
+    now_utc = datetime.now(timezone.utc)
+    tzname = now_utc.astimezone().tzinfo.tzname(now_utc)  # type: ignore
+    try:
+        agent = Agent(
+            model=args.model or 'openai:gpt-4o',
+            system_prompt=f"""\
+Help the user by responding to their request, the output should be concise and always written in markdown.
+The current date and time is {datetime.now()} {tzname}.
+The user is running {sys.platform}.""",
+        )
+    except UserError:
+        console.print(f'[red]Invalid model "{args.model}"[/red]')
+        return 1
+
+    stream = not args.no_stream
+
+    if prompt := cast(str, args.prompt):
+        try:
+            asyncio.run(ask_agent(agent, prompt, stream, console))
+        except KeyboardInterrupt:
+            pass
+        return 0
+
+    history = Path.home() / '.pai-prompt-history.txt'
+    session = PromptSession(history=FileHistory(str(history)))  # type: ignore
+    multiline = False
+    messages: list[ModelMessage] = []
+
+    while True:
+        try:
+            auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit'])
+            text = cast(str, session.prompt('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline))
+        except (KeyboardInterrupt, EOFError):
+            return 0
+
+        if not text.strip():
+            continue
+
+        ident_prompt = text.lower().strip(' ').replace(' ', '-').lstrip(' ')
+        if ident_prompt == '/markdown':
+            try:
+                parts = messages[-1].parts
+            except IndexError:
+                console.print('[dim]No markdown output available.[/dim]')
+                continue
+            for part in parts:
+                if part.part_kind == 'text':
+                    last_content = part.content
+                    console.print('[dim]Last markdown output of last question:[/dim]\n')
+                    console.print(Syntax(last_content, lexer='markdown', background_color='default'))
+
+            continue
+        if ident_prompt == '/multiline':
+            multiline = not multiline
+            if multiline:
+                console.print(
+                    'Enabling multiline mode. '
+                    '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
+                )
+            else:
+                console.print('Disabling multiline mode.')
+            continue
+        if ident_prompt == '/exit':
+            console.print('[dim]Exiting…[/dim]')
+            return 0
+
+        try:
+            messages = asyncio.run(ask_agent(agent, text, stream, console, messages))
+        except KeyboardInterrupt:
+            return 0
+
+
+async def ask_agent(
+    agent: Agent,
+    prompt: str,
+    stream: bool,
+    console: Console,
+    messages: list[ModelMessage] | None = None,
+) -> list[ModelMessage]:  # pragma: no cover
+    status: None | Status = Status('[dim]Working on it…[/dim]', console=console)
+    live = Live('', refresh_per_second=15, console=console)
+    status.start()
+
+    async with agent.iter(prompt, message_history=messages) as agent_run:
+        console.print('\nResponse:', style='green')
+
+        content: str = ''
+        interrupted = False
+        try:
+            node = agent_run.next_node
+            while not isinstance(node, End):
+                node = await agent_run.next(node)
+                if Agent.is_model_request_node(node):
+                    async with node.stream(agent_run.ctx) as handle_stream:
+                        # NOTE(Marcelo): It took me a lot of time to figure out how to stop `status` and start `live`
+                        # in a context manager, so I had to do it manually with `stop` and `start` methods.
+                        # PR welcome to simplify this code.
+                        if status is not None:
+                            status.stop()
+                            status = None
+                        if not live.is_started:
+                            live.start()
+                        async for event in handle_stream:
+                            if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
+                                if stream:
+                                    content += event.delta.content_delta
+                                    live.update(Markdown(content))
+        except KeyboardInterrupt:
+            interrupted = True
+        finally:
+            live.stop()
+
+        if interrupted:
+            console.print('[dim]Interrupted[/dim]')
+
+        assert agent_run.result
+        if not stream:
+            content = agent_run.result.data
+            console.print(Markdown(content))
+        return agent_run.result.all_messages()
+
+
+class CustomAutoSuggest(AutoSuggestFromHistory):
+    def __init__(self, special_suggestions: list[str] | None = None):  # pragma: no cover
+        super().__init__()
+        self.special_suggestions = special_suggestions or []
+
+    def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | None:  # pragma: no cover
+        # Get the suggestion from history
+        suggestion = super().get_suggestion(buffer, document)
+
+        # Check for custom suggestions
+        text = document.text_before_cursor.strip()
+        for special in self.special_suggestions:
+            if special.startswith(text):
+                return Suggestion(special[len(text) :])
+        return suggestion
+
+
+def app():  # pragma: no cover
+    sys.exit(cli())
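A minimal sketch of driving the new CLI programmatically, based on the `cli(args_list)` signature above; it assumes `OPENAI_API_KEY` is set, since the default model is `openai:gpt-4o`:

    from pydantic_ai._cli import cli

    # One-shot mode: the positional prompt is answered once and an exit code returned.
    exit_code = cli(['What is Pydantic?', '--no-stream'])

    # Interactive mode: omit the prompt and the prompt-toolkit REPL starts, where
    # `/markdown`, `/multiline` and `/exit` are handled as special prompts.
    # exit_code = cli([])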
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
 from __future__ import annotations as _annotations
 
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
+from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
 
 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
+from typing_extensions import get_origin
 
 from ._griffe import doc_descriptions
 from ._utils import check_object_json_schema, is_model_like
@@ -223,8 +224,7 @@ def _build_schema(
 
 
 def _is_call_ctx(annotation: Any) -> bool:
+    """Return whether the annotation is the `RunContext` class, parameterized or not."""
     from .tools import RunContext
 
-    return annotation is RunContext or (
-        _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
-    )
+    return annotation is RunContext or get_origin(annotation) is RunContext
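The new `_is_call_ctx` relies on `get_origin` returning the unsubscripted class for a parameterized generic and `None` for a bare class, which makes the old `_typing_extra.is_generic_alias` guard redundant; a small illustrative sketch using the names from the diff:

    from typing_extensions import get_origin
    from pydantic_ai.tools import RunContext

    assert get_origin(RunContext[int]) is RunContext  # parameterized: origin is the class
    assert get_origin(RunContext) is None             # bare class: caught by `annotation is RunContext`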
@@ -1,14 +1,14 @@
 from __future__ import annotations as _annotations
 
 import inspect
-import sys
-import types
 from collections.abc import Awaitable, Iterable, Iterator
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Literal, Union, cast
+from typing import Any, Callable, Generic, Literal, Union, cast
 
 from pydantic import TypeAdapter, ValidationError
-from typing_extensions import
+from typing_extensions import TypedDict, TypeVar, get_args, get_origin
+from typing_inspection import typing_objects
+from typing_inspection.introspection import is_union_origin
 
 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
@@ -248,23 +248,12 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
 
 
 def get_union_args(tp: Any) -> tuple[Any, ...]:
-    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty
-    if
+    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
+    if typing_objects.is_typealiastype(tp):
         tp = tp.__value__
 
     origin = get_origin(tp)
-    if origin_is_union(origin):
+    if is_union_origin(origin):
         return get_args(tp)
     else:
         return ()
-
-
-if sys.version_info < (3, 10):
-
-    def origin_is_union(tp: type[Any] | None) -> bool:
-        return tp is Union
-
-else:
-
-    def origin_is_union(tp: type[Any] | None) -> bool:
-        return tp is Union or tp is types.UnionType
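`is_union_origin` from `typing_inspection` covers both `typing.Union` and the PEP 604 `X | Y` form (`types.UnionType`), which is what allowed the version-gated `origin_is_union` helper to be deleted; roughly:

    from typing import Union
    from typing_extensions import get_args, get_origin
    from typing_inspection.introspection import is_union_origin

    assert is_union_origin(get_origin(Union[int, str]))
    assert is_union_origin(get_origin(int | str))  # types.UnionType on Python 3.10+
    assert get_args(Union[int, str]) == (int, str)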
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
 from copy import deepcopy
 from types import FrameType
-from typing import Any, Callable, Generic, cast, final, overload
+from typing import Any, Callable, ClassVar, Generic, cast, final, overload
 
 from opentelemetry.trace import NoOpTracer, use_span
 from typing_extensions import TypeGuard, TypeVar, deprecated
@@ -25,7 +25,7 @@ from . import (
     result,
     usage as _usage,
 )
-from .models.instrumented import InstrumentedModel
+from .models.instrumented import InstrumentationSettings, InstrumentedModel
 from .result import FinalResult, ResultDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -56,6 +56,7 @@ __all__ = (
     'CallToolsNode',
     'ModelRequestNode',
     'UserPromptNode',
+    'InstrumentationSettings',
 )
 
 
@@ -112,8 +113,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     The type of the result data, used to validate the result data, defaults to `str`.
     """
 
-    instrument: bool
-    """
+    instrument: InstrumentationSettings | bool | None
+    """Options to automatically instrument with OpenTelemetry."""
+
+    _instrument_default: ClassVar[InstrumentationSettings | bool] = False
 
     _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _result_tool_name: str = dataclasses.field(repr=False)
@@ -147,7 +150,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
-        instrument: bool =
+        instrument: InstrumentationSettings | bool | None = None,
     ):
         """Create an agent.
 
@@ -177,7 +180,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 [override the model][pydantic_ai.Agent.override] for testing.
             end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                 See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
-            instrument:
+            instrument: Set to True to automatically instrument with OpenTelemetry,
+                which will use Logfire if it's configured.
+                Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize.
+                If this isn't set, then the last value set by
+                [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
+                will be used, which defaults to False.
         """
         if model is None or defer_model_check:
             self.model = model
@@ -213,6 +221,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             self._register_tool(Tool(tool))
 
+    @staticmethod
+    def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
+        """Set the instrumentation options for all agents where `instrument` is not set."""
+        Agent._instrument_default = instrument
+
     @overload
     async def run(
         self,
@@ -422,7 +435,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         usage_limits = usage_limits or _usage.UsageLimits()
 
         if isinstance(model_used, InstrumentedModel):
-            tracer = model_used.tracer
+            tracer = model_used.options.tracer
         else:
             tracer = NoOpTracer()
         agent_name = self.name or 'agent'
@@ -1119,8 +1132,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-
-
+        instrument = self.instrument
+        if instrument is None:
+            instrument = self._instrument_default
+
+        if instrument and not isinstance(model_, InstrumentedModel):
+            if instrument is True:
+                instrument = InstrumentationSettings()
+
+            model_ = InstrumentedModel(model_, instrument)
 
         return model_
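Taken together, these changes let instrumentation be configured per agent or process-wide; a sketch of both styles (assuming OpenTelemetry/Logfire is already configured):

    from pydantic_ai import Agent
    from pydantic_ai.agent import InstrumentationSettings

    # Per-agent: the model is wrapped in InstrumentedModel when the agent runs.
    agent = Agent('openai:gpt-4o', instrument=True)

    # Process-wide default, picked up by every agent whose `instrument` is left as None.
    Agent.instrument_all(InstrumentationSettings())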
@@ -189,7 +189,10 @@ class ToolReturnPart:
         return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
     def otel_event(self) -> Event:
-        return Event(
+        return Event(
+            'gen_ai.tool.message',
+            body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name},
+        )
 
 
 error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
@@ -244,7 +247,13 @@ class RetryPromptPart:
             return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
         else:
             return Event(
-                'gen_ai.tool.message',
+                'gen_ai.tool.message',
+                body={
+                    'content': self.model_response(),
+                    'role': 'tool',
+                    'id': self.tool_call_id,
+                    'name': self.tool_name,
+                },
             )
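Both `gen_ai.tool.message` events now carry the tool call id and name in their body; a sketch with `ToolReturnPart` (field names as in `messages.py`, construction details may vary by version):

    from pydantic_ai.messages import ToolReturnPart

    part = ToolReturnPart(tool_name='get_weather', content='sunny', tool_call_id='call_1')
    event = part.otel_event()
    # event body: {'content': 'sunny', 'role': 'tool', 'id': 'call_1', 'name': 'get_weather'}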
@@ -49,6 +49,8 @@ KnownModelName = Literal[
     'cohere:command-r-plus-04-2024',
     'cohere:command-r-plus-08-2024',
     'cohere:command-r7b-12-2024',
+    'deepseek:deepseek-chat',
+    'deepseek:deepseek-reasoner',
     'google-gla:gemini-1.0-pro',
     'google-gla:gemini-1.5-flash',
     'google-gla:gemini-1.5-flash-8b',
@@ -320,54 +322,52 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .test import TestModel
 
         return TestModel()
-    elif model.startswith('cohere:'):
-        from .cohere import CohereModel
 
-
-
-
+    try:
+        provider, model_name = model.split(':')
+    except ValueError:
+        model_name = model
+        # TODO(Marcelo): We should deprecate this way.
+        if model_name.startswith(('gpt', 'o1', 'o3')):
+            provider = 'openai'
+        elif model_name.startswith('claude'):
+            provider = 'anthropic'
+        elif model_name.startswith('gemini'):
+            provider = 'google-gla'
+        else:
+            raise UserError(f'Unknown model: {model}')
+
+    if provider == 'vertexai':
+        provider = 'google-vertex'
+
+    if provider == 'cohere':
+        from .cohere import CohereModel
 
-
-
+        # TODO(Marcelo): Missing provider API.
+        return CohereModel(model_name)
+    elif provider in ('deepseek', 'openai'):
         from .openai import OpenAIModel
 
-        return OpenAIModel(
-    elif
-        from .gemini import GeminiModel
-
-        return GeminiModel(model[11:])
-    # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
-    elif model.startswith('gemini'):
+        return OpenAIModel(model_name, provider=provider)
+    elif provider in ('google-gla', 'google-vertex'):
         from .gemini import GeminiModel
 
-
-
-    elif model.startswith('groq:'):
+        return GeminiModel(model_name, provider=provider)
+    elif provider == 'groq':
        from .groq import GroqModel
 
-
-
-
-
-        return VertexAIModel(model[14:])
-    # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
-    elif model.startswith('vertexai:'):
-        from .vertexai import VertexAIModel
-
-        return VertexAIModel(model[9:])
-    elif model.startswith('mistral:'):
+        # TODO(Marcelo): Missing provider API.
+        return GroqModel(model_name)
+    elif provider == 'mistral':
         from .mistral import MistralModel
 
-
-
-
-
-        return AnthropicModel(model[10:])
-    # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
-    elif model.startswith('claude'):
+        # TODO(Marcelo): Missing provider API.
+        return MistralModel(model_name)
+    elif provider == 'anthropic':
         from .anthropic import AnthropicModel
 
-
+        # TODO(Marcelo): Missing provider API.
+        return AnthropicModel(model_name)
     else:
         raise UserError(f'Unknown model: {model}')
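The reworked `infer_model` splits a `'provider:model'` string once on `':'` and dispatches on the provider, keeping bare model names as a to-be-deprecated fallback; for example (actually constructing a model still needs the matching API key, e.g. `DEEPSEEK_API_KEY` or `GEMINI_API_KEY`):

    from pydantic_ai.models import infer_model

    m1 = infer_model('deepseek:deepseek-chat')     # OpenAIModel with provider='deepseek'
    m2 = infer_model('vertexai:gemini-1.5-flash')  # 'vertexai' is rewritten to 'google-vertex'
    m3 = infer_model('gemini-1.5-flash')           # bare name, prefix-sniffed to 'google-gla'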
@@ -8,12 +8,14 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union, cast
+from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
 from uuid import uuid4
 
 import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
+
+from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
@@ -82,17 +84,39 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-
+    client: AsyncHTTPClient = field(repr=False)
 
     _model_name: GeminiModelName = field(repr=False)
+    _provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str | None = field(default='google-gla', repr=False)
 
+    @overload
+    def __init__(
+        self,
+        model_name: GeminiModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
+    ) -> None: ...
+
+    @deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
+    @overload
     def __init__(
         self,
         model_name: GeminiModelName,
         *,
+        provider: None = None,
+        api_key: str | None = None,
+        http_client: AsyncHTTPClient | None = None,
+        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
+    ) -> None: ...
+
+    def __init__(
+        self,
+        model_name: GeminiModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
         api_key: str | None = None,
         http_client: AsyncHTTPClient | None = None,
         url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
@@ -101,6 +125,7 @@ class GeminiModel(Model):
 
         Args:
             model_name: The name of the model to use.
+            provider: The provider to use for the model.
            api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                 will be used if available.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
@@ -109,14 +134,24 @@ class GeminiModel(Model):
                 `model` is substituted with the model name, and `function` is added to the end of the URL.
         """
         self._model_name = model_name
-
-
-
+        self._provider = provider
+
+        if provider is not None:
+            if isinstance(provider, str):
+                self._system = provider
+                self.client = infer_provider(provider).client
             else:
-
-
-
-
+                self._system = provider.name
+                self.client = provider.client
+        else:
+            if api_key is None:
+                if env_api_key := os.getenv('GEMINI_API_KEY'):
+                    api_key = env_api_key
+                else:
+                    raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
+            self.client = http_client or cached_async_http_client()
+            self._auth = ApiKeyAuth(api_key)
+            self._url = url_template.format(model=model_name)
 
     @property
     def auth(self) -> AuthProtocol:
@@ -217,17 +252,19 @@ class GeminiModel(Model):
         if generation_config:
             request_data['generation_config'] = generation_config
 
-        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
-
         headers = {
             'Content-Type': 'application/json',
             'User-Agent': get_user_agent(),
-            **await self.auth.headers(),
         }
+        if self._provider is None:  # pragma: no cover
+            url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
+            headers.update(await self.auth.headers())
+        else:
+            url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.
+        async with self.client.stream(
             'POST',
             url,
             content=request_json,
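Per the overloads above, the provider-based constructor is now the preferred way to build a `GeminiModel`, with the `api_key`/`http_client`/`url_template` form deprecated; a sketch (assuming `GEMINI_API_KEY` is set for the `google-gla` provider):

    from pydantic_ai.models.gemini import GeminiModel

    model = GeminiModel('gemini-1.5-flash', provider='google-gla')  # new provider-based style
    # model = GeminiModel('gemini-1.5-flash', api_key='...')        # old style, now deprecated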