pydantic-ai-slim 0.0.32__tar.gz → 0.0.34__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (43)
  1. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/PKG-INFO +7 -2
  2. pydantic_ai_slim-0.0.34/pydantic_ai/_cli.py +225 -0
  3. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_pydantic.py +4 -4
  4. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_result.py +7 -18
  5. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/agent.py +29 -9
  6. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/messages.py +11 -2
  7. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/__init__.py +36 -36
  8. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/gemini.py +51 -14
  9. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/instrumented.py +43 -9
  10. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/openai.py +56 -15
  11. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/vertexai.py +9 -1
  12. pydantic_ai_slim-0.0.34/pydantic_ai/providers/__init__.py +64 -0
  13. pydantic_ai_slim-0.0.34/pydantic_ai/providers/deepseek.py +68 -0
  14. pydantic_ai_slim-0.0.34/pydantic_ai/providers/google_gla.py +44 -0
  15. pydantic_ai_slim-0.0.34/pydantic_ai/providers/google_vertex.py +200 -0
  16. pydantic_ai_slim-0.0.34/pydantic_ai/providers/openai.py +72 -0
  17. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pyproject.toml +8 -2
  18. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/.gitignore +0 -0
  19. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/README.md +0 -0
  20. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/__init__.py +0 -0
  21. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_agent_graph.py +0 -0
  22. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_griffe.py +0 -0
  23. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_parts_manager.py +0 -0
  24. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_system_prompt.py +0 -0
  25. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/_utils.py +0 -0
  26. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/common_tools/__init__.py +0 -0
  27. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  28. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/common_tools/tavily.py +0 -0
  29. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/exceptions.py +0 -0
  30. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/format_as_xml.py +0 -0
  31. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/anthropic.py +0 -0
  32. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/cohere.py +0 -0
  33. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/fallback.py +0 -0
  34. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/function.py +0 -0
  35. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/groq.py +0 -0
  36. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/mistral.py +0 -0
  37. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/test.py +0 -0
  38. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/models/wrapper.py +0 -0
  39. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/py.typed +0 -0
  40. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/result.py +0 -0
  41. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/settings.py +0 -0
  42. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/tools.py +0 -0
  43. {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.34}/pydantic_ai/usage.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.0.32
+ Version: 0.0.34
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>
  License-Expression: MIT
@@ -29,10 +29,15 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: opentelemetry-api>=1.28.0
- Requires-Dist: pydantic-graph==0.0.32
+ Requires-Dist: pydantic-graph==0.0.34
  Requires-Dist: pydantic>=2.10
+ Requires-Dist: typing-inspection>=0.4.0
  Provides-Extra: anthropic
  Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
+ Provides-Extra: cli
+ Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
+ Requires-Dist: prompt-toolkit>=3; extra == 'cli'
+ Requires-Dist: rich>=13; extra == 'cli'
  Provides-Extra: cohere
  Requires-Dist: cohere>=5.13.11; extra == 'cohere'
  Provides-Extra: duckduckgo
pydantic_ai/_cli.py (new file)
@@ -0,0 +1,225 @@
+ from __future__ import annotations as _annotations
+
+ import argparse
+ import asyncio
+ import sys
+ from collections.abc import Sequence
+ from datetime import datetime, timezone
+ from importlib.metadata import version
+ from pathlib import Path
+ from typing import cast
+
+ from typing_inspection.introspection import get_literal_values
+
+ from pydantic_ai.exceptions import UserError
+ from pydantic_ai.models import KnownModelName
+ from pydantic_graph.nodes import End
+
+ try:
+     import argcomplete
+     from prompt_toolkit import PromptSession
+     from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, Suggestion
+     from prompt_toolkit.buffer import Buffer
+     from prompt_toolkit.document import Document
+     from prompt_toolkit.history import FileHistory
+     from rich.console import Console, ConsoleOptions, RenderResult
+     from rich.live import Live
+     from rich.markdown import CodeBlock, Markdown
+     from rich.status import Status
+     from rich.syntax import Syntax
+     from rich.text import Text
+ except ImportError as _import_error:
+     raise ImportError(
+         'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the PydanticAI CLI, '
+         "you can use the `cli` optional group — `pip install 'pydantic-ai-slim[cli]'`"
+     ) from _import_error
+
+ from pydantic_ai.agent import Agent
+ from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
+
+ __version__ = version('pydantic-ai')
+
+
+ class SimpleCodeBlock(CodeBlock):
+     def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:  # pragma: no cover
+         code = str(self.text).rstrip()
+         yield Text(self.lexer_name, style='dim')
+         yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True)
+         yield Text(f'/{self.lexer_name}', style='dim')
+
+
+ Markdown.elements['fence'] = SimpleCodeBlock
+
+
+ def cli(args_list: Sequence[str] | None = None) -> int:  # noqa: C901  # pragma: no cover
+     parser = argparse.ArgumentParser(
+         prog='pai',
+         description=f"""\
+ PydanticAI CLI v{__version__}\n\n
+
+ Special prompt:
+ * `/exit` - exit the interactive mode
+ * `/markdown` - show the last markdown output of the last question
+ * `/multiline` - toggle multiline mode
+ """,
+         formatter_class=argparse.RawTextHelpFormatter,
+     )
+     parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
+     parser.add_argument(
+         '--model',
+         nargs='?',
+         help='Model to use, it should be "<provider>:<model>" e.g. "openai:gpt-4o". If omitted it will default to "openai:gpt-4o"',
+         default='openai:gpt-4o',
+     ).completer = argcomplete.ChoicesCompleter(list(get_literal_values(KnownModelName)))  # type: ignore[reportPrivateUsage]
+     parser.add_argument('--no-stream', action='store_true', help='Whether to stream responses from OpenAI')
+     parser.add_argument('--version', action='store_true', help='Show version and exit')
+
+     argcomplete.autocomplete(parser)
+     args = parser.parse_args(args_list)
+
+     console = Console()
+     console.print(f'pai - PydanticAI CLI v{__version__}', style='green bold', highlight=False)
+     if args.version:
+         return 0
+
+     now_utc = datetime.now(timezone.utc)
+     tzname = now_utc.astimezone().tzinfo.tzname(now_utc)  # type: ignore
+     try:
+         agent = Agent(
+             model=args.model or 'openai:gpt-4o',
+             system_prompt=f"""\
+ Help the user by responding to their request, the output should be concise and always written in markdown.
+ The current date and time is {datetime.now()} {tzname}.
+ The user is running {sys.platform}.""",
+         )
+     except UserError:
+         console.print(f'[red]Invalid model "{args.model}"[/red]')
+         return 1
+
+     stream = not args.no_stream
+
+     if prompt := cast(str, args.prompt):
+         try:
+             asyncio.run(ask_agent(agent, prompt, stream, console))
+         except KeyboardInterrupt:
+             pass
+         return 0
+
+     history = Path.home() / '.pai-prompt-history.txt'
+     session = PromptSession(history=FileHistory(str(history)))  # type: ignore
+     multiline = False
+     messages: list[ModelMessage] = []
+
+     while True:
+         try:
+             auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit'])
+             text = cast(str, session.prompt('pai ➤ ', auto_suggest=auto_suggest, multiline=multiline))
+         except (KeyboardInterrupt, EOFError):
+             return 0
+
+         if not text.strip():
+             continue
+
+         ident_prompt = text.lower().strip(' ').replace(' ', '-').lstrip(' ')
+         if ident_prompt == '/markdown':
+             try:
+                 parts = messages[-1].parts
+             except IndexError:
+                 console.print('[dim]No markdown output available.[/dim]')
+                 continue
+             for part in parts:
+                 if part.part_kind == 'text':
+                     last_content = part.content
+                     console.print('[dim]Last markdown output of last question:[/dim]\n')
+                     console.print(Syntax(last_content, lexer='markdown', background_color='default'))
+
+             continue
+         if ident_prompt == '/multiline':
+             multiline = not multiline
+             if multiline:
+                 console.print(
+                     'Enabling multiline mode. '
+                     '[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
+                 )
+             else:
+                 console.print('Disabling multiline mode.')
+             continue
+         if ident_prompt == '/exit':
+             console.print('[dim]Exiting…[/dim]')
+             return 0
+
+         try:
+             messages = asyncio.run(ask_agent(agent, text, stream, console, messages))
+         except KeyboardInterrupt:
+             return 0
+
+
+ async def ask_agent(
+     agent: Agent,
+     prompt: str,
+     stream: bool,
+     console: Console,
+     messages: list[ModelMessage] | None = None,
+ ) -> list[ModelMessage]:  # pragma: no cover
+     status: None | Status = Status('[dim]Working on it…[/dim]', console=console)
+     live = Live('', refresh_per_second=15, console=console)
+     status.start()
+
+     async with agent.iter(prompt, message_history=messages) as agent_run:
+         console.print('\nResponse:', style='green')
+
+         content: str = ''
+         interrupted = False
+         try:
+             node = agent_run.next_node
+             while not isinstance(node, End):
+                 node = await agent_run.next(node)
+                 if Agent.is_model_request_node(node):
+                     async with node.stream(agent_run.ctx) as handle_stream:
+                         # NOTE(Marcelo): It took me a lot of time to figure out how to stop `status` and start `live`
+                         # in a context manager, so I had to do it manually with `stop` and `start` methods.
+                         # PR welcome to simplify this code.
+                         if status is not None:
+                             status.stop()
+                             status = None
+                         if not live.is_started:
+                             live.start()
+                         async for event in handle_stream:
+                             if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
+                                 if stream:
+                                     content += event.delta.content_delta
+                                     live.update(Markdown(content))
+         except KeyboardInterrupt:
+             interrupted = True
+         finally:
+             live.stop()
+
+         if interrupted:
+             console.print('[dim]Interrupted[/dim]')
+
+         assert agent_run.result
+         if not stream:
+             content = agent_run.result.data
+             console.print(Markdown(content))
+         return agent_run.result.all_messages()
+
+
+ class CustomAutoSuggest(AutoSuggestFromHistory):
+     def __init__(self, special_suggestions: list[str] | None = None):  # pragma: no cover
+         super().__init__()
+         self.special_suggestions = special_suggestions or []
+
+     def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | None:  # pragma: no cover
+         # Get the suggestion from history
+         suggestion = super().get_suggestion(buffer, document)
+
+         # Check for custom suggestions
+         text = document.text_before_cursor.strip()
+         for special in self.special_suggestions:
+             if special.startswith(text):
+                 return Suggestion(special[len(text) :])
+         return suggestion
+
+
+ def app():  # pragma: no cover
+     sys.exit(cli())
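The hunk above adds the interactive CLI in full. As a minimal sketch, based only on the `cli()` signature shown above, the new entry point can also be driven programmatically (the `pai` console-script wiring lives in the `pyproject.toml` change, which this diff only summarizes; model credentials such as `OPENAI_API_KEY` are read from the environment at run time):

```python
# Sketch, assuming the optional CLI dependencies are installed:
#   pip install 'pydantic-ai-slim[cli]'
import sys

from pydantic_ai._cli import cli

# One-shot mode: pass a prompt and cli() returns an exit code.
code = cli(['--model', 'openai:gpt-4o', 'Summarise this release in one line'])

# Interactive mode: calling cli([]) with no prompt drops into the prompt-toolkit loop.
sys.exit(code)
```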
pydantic_ai/_pydantic.py
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
  from __future__ import annotations as _annotations

  from inspect import Parameter, signature
- from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
+ from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast

  from pydantic import ConfigDict
  from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo
  from pydantic.json_schema import GenerateJsonSchema
  from pydantic.plugin._schema_validator import create_schema_validator
  from pydantic_core import SchemaValidator, core_schema
+ from typing_extensions import get_origin

  from ._griffe import doc_descriptions
  from ._utils import check_object_json_schema, is_model_like
@@ -223,8 +224,7 @@ def _build_schema(


  def _is_call_ctx(annotation: Any) -> bool:
+     """Return whether the annotation is the `RunContext` class, parameterized or not."""
      from .tools import RunContext

-     return annotation is RunContext or (
-         _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
-     )
+     return annotation is RunContext or get_origin(annotation) is RunContext
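The `_is_call_ctx` simplification works because `typing_extensions.get_origin` already returns `None` for anything that is not a parameterized generic, so the old `is_generic_alias` guard was redundant. A quick illustration (the assertions are illustrative, not taken from the package's tests):

```python
from typing_extensions import get_origin

from pydantic_ai.tools import RunContext

assert get_origin(RunContext[int]) is RunContext  # parameterized alias resolves to the class
assert get_origin(RunContext) is None             # bare class: handled by the `is` check
assert get_origin(int) is None                    # unrelated annotation: falls through to False
```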
pydantic_ai/_result.py
@@ -1,14 +1,14 @@
  from __future__ import annotations as _annotations

  import inspect
- import sys
- import types
  from collections.abc import Awaitable, Iterable, Iterator
  from dataclasses import dataclass, field
- from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
+ from typing import Any, Callable, Generic, Literal, Union, cast

  from pydantic import TypeAdapter, ValidationError
- from typing_extensions import TypeAliasType, TypedDict, TypeVar
+ from typing_extensions import TypedDict, TypeVar, get_args, get_origin
+ from typing_inspection import typing_objects
+ from typing_inspection.introspection import is_union_origin

  from . import _utils, messages as _messages
  from .exceptions import ModelRetry
@@ -248,23 +248,12 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:


  def get_union_args(tp: Any) -> tuple[Any, ...]:
-     """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty union."""
-     if isinstance(tp, TypeAliasType):
+     """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
+     if typing_objects.is_typealiastype(tp):
          tp = tp.__value__

      origin = get_origin(tp)
-     if origin_is_union(origin):
+     if is_union_origin(origin):
          return get_args(tp)
      else:
          return ()
-
-
- if sys.version_info < (3, 10):
-
-     def origin_is_union(tp: type[Any] | None) -> bool:
-         return tp is Union
-
-     else:
-
-     def origin_is_union(tp: type[Any] | None) -> bool:
-         return tp is Union or tp is types.UnionType
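`typing_inspection.introspection.is_union_origin` covers both `typing.Union` and the PEP 604 `X | Y` origin (`types.UnionType` on Python 3.10+), which is exactly what the deleted version-gated `origin_is_union` helpers did. A condensed sketch of the resulting behaviour, with the `TypeAliasType` unwrapping omitted:

```python
from typing import Optional, Union

from typing_extensions import get_args, get_origin
from typing_inspection.introspection import is_union_origin


def union_args(tp: object) -> tuple[object, ...]:
    # Mirrors get_union_args() above, minus the TypeAliasType unwrapping.
    origin = get_origin(tp)
    return get_args(tp) if is_union_origin(origin) else ()


assert union_args(Union[int, str]) == (int, str)
assert union_args(Optional[int]) == (int, type(None))
assert union_args(int) == ()
```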
pydantic_ai/agent.py
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
  from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
  from copy import deepcopy
  from types import FrameType
- from typing import Any, Callable, Generic, cast, final, overload
+ from typing import Any, Callable, ClassVar, Generic, cast, final, overload

  from opentelemetry.trace import NoOpTracer, use_span
  from typing_extensions import TypeGuard, TypeVar, deprecated
@@ -25,7 +25,7 @@ from . import (
      result,
      usage as _usage,
  )
- from .models.instrumented import InstrumentedModel
+ from .models.instrumented import InstrumentationSettings, InstrumentedModel
  from .result import FinalResult, ResultDataT, StreamedRunResult
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
@@ -56,6 +56,7 @@ __all__ = (
      'CallToolsNode',
      'ModelRequestNode',
      'UserPromptNode',
+     'InstrumentationSettings',
  )


@@ -112,8 +113,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
      The type of the result data, used to validate the result data, defaults to `str`.
      """

-     instrument: bool
-     """Automatically instrument with OpenTelemetry. Will use Logfire if it's configured."""
+     instrument: InstrumentationSettings | bool | None
+     """Options to automatically instrument with OpenTelemetry."""
+
+     _instrument_default: ClassVar[InstrumentationSettings | bool] = False

      _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
      _result_tool_name: str = dataclasses.field(repr=False)
@@ -147,7 +150,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
          tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
          defer_model_check: bool = False,
          end_strategy: EndStrategy = 'early',
-         instrument: bool = False,
+         instrument: InstrumentationSettings | bool | None = None,
      ):
          """Create an agent.

@@ -177,7 +180,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                  [override the model][pydantic_ai.Agent.override] for testing.
              end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                  See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
-             instrument: Automatically instrument with OpenTelemetry. Will use Logfire if it's configured.
+             instrument: Set to True to automatically instrument with OpenTelemetry,
+                 which will use Logfire if it's configured.
+                 Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize.
+                 If this isn't set, then the last value set by
+                 [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
+                 will be used, which defaults to False.
          """
          if model is None or defer_model_check:
              self.model = model
@@ -213,6 +221,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
              else:
                  self._register_tool(Tool(tool))

+     @staticmethod
+     def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
+         """Set the instrumentation options for all agents where `instrument` is not set."""
+         Agent._instrument_default = instrument
+
      @overload
      async def run(
          self,
@@ -422,7 +435,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
          usage_limits = usage_limits or _usage.UsageLimits()

          if isinstance(model_used, InstrumentedModel):
-             tracer = model_used.tracer
+             tracer = model_used.options.tracer
          else:
              tracer = NoOpTracer()
          agent_name = self.name or 'agent'
@@ -1119,8 +1132,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
          else:
              raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

-         if self.instrument and not isinstance(model_, InstrumentedModel):
-             model_ = InstrumentedModel(model_)
+         instrument = self.instrument
+         if instrument is None:
+             instrument = self._instrument_default
+
+         if instrument and not isinstance(model_, InstrumentedModel):
+             if instrument is True:
+                 instrument = InstrumentationSettings()
+
+             model_ = InstrumentedModel(model_, instrument)

          return model_
1146
 
pydantic_ai/messages.py
@@ -189,7 +189,10 @@ class ToolReturnPart:
          return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}

      def otel_event(self) -> Event:
-         return Event('gen_ai.tool.message', body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id})
+         return Event(
+             'gen_ai.tool.message',
+             body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name},
+         )


  error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
@@ -244,7 +247,13 @@ class RetryPromptPart:
              return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
          else:
              return Event(
-                 'gen_ai.tool.message', body={'content': self.model_response(), 'role': 'tool', 'id': self.tool_call_id}
+                 'gen_ai.tool.message',
+                 body={
+                     'content': self.model_response(),
+                     'role': 'tool',
+                     'id': self.tool_call_id,
+                     'name': self.tool_name,
+                 },
              )
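Both OTel events now carry the tool name alongside the call id. Roughly, for a tool return part (field values are illustrative):

```python
from pydantic_ai.messages import ToolReturnPart

part = ToolReturnPart(tool_name='get_weather', content='sunny', tool_call_id='call_1')
print(part.otel_event().body)
# -> {'content': 'sunny', 'role': 'tool', 'id': 'call_1', 'name': 'get_weather'}
```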
pydantic_ai/models/__init__.py
@@ -49,6 +49,8 @@ KnownModelName = Literal[
      'cohere:command-r-plus-04-2024',
      'cohere:command-r-plus-08-2024',
      'cohere:command-r7b-12-2024',
+     'deepseek:deepseek-chat',
+     'deepseek:deepseek-reasoner',
      'google-gla:gemini-1.0-pro',
      'google-gla:gemini-1.5-flash',
      'google-gla:gemini-1.5-flash-8b',
@@ -320,54 +322,52 @@ def infer_model(model: Model | KnownModelName) -> Model:
          from .test import TestModel

          return TestModel()
-     elif model.startswith('cohere:'):
-         from .cohere import CohereModel
-
-         return CohereModel(model[7:])
-     elif model.startswith('openai:'):
-         from .openai import OpenAIModel
-
-         return OpenAIModel(model[7:])
-     elif model.startswith(('gpt', 'o1', 'o3')):
+
+     try:
+         provider, model_name = model.split(':')
+     except ValueError:
+         model_name = model
+         # TODO(Marcelo): We should deprecate this way.
+         if model_name.startswith(('gpt', 'o1', 'o3')):
+             provider = 'openai'
+         elif model_name.startswith('claude'):
+             provider = 'anthropic'
+         elif model_name.startswith('gemini'):
+             provider = 'google-gla'
+         else:
+             raise UserError(f'Unknown model: {model}')
+
+     if provider == 'vertexai':
+         provider = 'google-vertex'
+
+     if provider == 'cohere':
+         from .cohere import CohereModel
+
+         # TODO(Marcelo): Missing provider API.
+         return CohereModel(model_name)
+     elif provider in ('deepseek', 'openai'):
          from .openai import OpenAIModel

-         return OpenAIModel(model)
-     elif model.startswith('google-gla'):
-         from .gemini import GeminiModel
-
-         return GeminiModel(model[11:])
-     # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
-     elif model.startswith('gemini'):
+         return OpenAIModel(model_name, provider=provider)
+     elif provider in ('google-gla', 'google-vertex'):
          from .gemini import GeminiModel

-         # noinspection PyTypeChecker
-         return GeminiModel(model)
-     elif model.startswith('groq:'):
+         return GeminiModel(model_name, provider=provider)
+     elif provider == 'groq':
          from .groq import GroqModel

-         return GroqModel(model[5:])
-     elif model.startswith('google-vertex'):
-         from .vertexai import VertexAIModel
-
-         return VertexAIModel(model[14:])
-     # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
-     elif model.startswith('vertexai:'):
-         from .vertexai import VertexAIModel
-
-         return VertexAIModel(model[9:])
-     elif model.startswith('mistral:'):
+         # TODO(Marcelo): Missing provider API.
+         return GroqModel(model_name)
+     elif provider == 'mistral':
          from .mistral import MistralModel

-         return MistralModel(model[8:])
-     elif model.startswith('anthropic'):
-         from .anthropic import AnthropicModel
-
-         return AnthropicModel(model[10:])
-     # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
-     elif model.startswith('claude'):
+         # TODO(Marcelo): Missing provider API.
+         return MistralModel(model_name)
+     elif provider == 'anthropic':
          from .anthropic import AnthropicModel

+         # TODO(Marcelo): Missing provider API.
+         return AnthropicModel(model_name)
      else:
          raise UserError(f'Unknown model: {model}')
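The rewritten `infer_model` splits on the first `:` to get a provider prefix, keeps a bare-name fallback (flagged for deprecation in the TODO), and normalizes `vertexai` to `google-vertex`. For example (per the branches above; provider credentials are read from the environment when the model is constructed):

```python
from pydantic_ai.models import infer_model

infer_model('deepseek:deepseek-chat')     # OpenAIModel with provider='deepseek'
infer_model('vertexai:gemini-1.5-flash')  # normalized to the 'google-vertex' provider
infer_model('gpt-4o')                     # bare name: provider inferred as 'openai' (legacy path)
```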
pydantic_ai/models/gemini.py
@@ -8,12 +8,14 @@ from contextlib import asynccontextmanager
  from copy import deepcopy
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import Annotated, Any, Literal, Protocol, Union, cast
+ from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
  from uuid import uuid4

  import pydantic
  from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
- from typing_extensions import NotRequired, TypedDict, assert_never
+ from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
+
+ from pydantic_ai.providers import Provider, infer_provider

  from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
  from ..messages import (
@@ -82,17 +84,39 @@ class GeminiModel(Model):
      Apart from `__init__`, all methods are private or match those of the base class.
      """

-     http_client: AsyncHTTPClient = field(repr=False)
+     client: AsyncHTTPClient = field(repr=False)

      _model_name: GeminiModelName = field(repr=False)
+     _provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
      _auth: AuthProtocol | None = field(repr=False)
      _url: str | None = field(repr=False)
      _system: str | None = field(default='google-gla', repr=False)

+     @overload
+     def __init__(
+         self,
+         model_name: GeminiModelName,
+         *,
+         provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
+     ) -> None: ...
+
+     @deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
+     @overload
      def __init__(
          self,
          model_name: GeminiModelName,
          *,
+         provider: None = None,
+         api_key: str | None = None,
+         http_client: AsyncHTTPClient | None = None,
+         url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
+     ) -> None: ...
+
+     def __init__(
+         self,
+         model_name: GeminiModelName,
+         *,
+         provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
          api_key: str | None = None,
          http_client: AsyncHTTPClient | None = None,
          url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
@@ -101,6 +125,7 @@ class GeminiModel(Model):

          Args:
              model_name: The name of the model to use.
+             provider: The provider to use for the model.
              api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                  will be used if available.
              http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
@@ -109,14 +134,24 @@ class GeminiModel(Model):
                  `model` is substituted with the model name, and `function` is added to the end of the URL.
          """
          self._model_name = model_name
-         if api_key is None:
-             if env_api_key := os.getenv('GEMINI_API_KEY'):
-                 api_key = env_api_key
+         self._provider = provider
+
+         if provider is not None:
+             if isinstance(provider, str):
+                 self._system = provider
+                 self.client = infer_provider(provider).client
              else:
-                 raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-         self.http_client = http_client or cached_async_http_client()
-         self._auth = ApiKeyAuth(api_key)
-         self._url = url_template.format(model=model_name)
+                 self._system = provider.name
+                 self.client = provider.client
+         else:
+             if api_key is None:
+                 if env_api_key := os.getenv('GEMINI_API_KEY'):
+                     api_key = env_api_key
+                 else:
+                     raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
+             self.client = http_client or cached_async_http_client()
+             self._auth = ApiKeyAuth(api_key)
+             self._url = url_template.format(model=model_name)

      @property
      def auth(self) -> AuthProtocol:
@@ -217,17 +252,19 @@ class GeminiModel(Model):
          if generation_config:
              request_data['generation_config'] = generation_config

-         url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
-
          headers = {
              'Content-Type': 'application/json',
              'User-Agent': get_user_agent(),
-             **await self.auth.headers(),
          }
+         if self._provider is None:  # pragma: no cover
+             url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
+             headers.update(await self.auth.headers())
+         else:
+             url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

          request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

-         async with self.http_client.stream(
+         async with self.client.stream(
              'POST',
              url,
              content=request_json,
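With the hunks above, `GeminiModel` gains a `provider` argument and routes requests through the provider's pre-configured `httpx` client and base URL, while the old `api_key`/`http_client`/`url_template` path is kept but deprecated. A sketch of both construction styles (credentials still come from the environment, e.g. `GEMINI_API_KEY`):

```python
from pydantic_ai.models.gemini import GeminiModel

# New style: the named provider supplies the client and base URL.
model = GeminiModel('gemini-1.5-flash', provider='google-gla')

# Old style: still accepted, but deprecated per the overload above.
legacy = GeminiModel('gemini-1.5-flash', api_key='...')
```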