pydantic-ai-slim 0.0.31__tar.gz → 0.0.33__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/PKG-INFO +4 -4
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_agent_graph.py +39 -38
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_pydantic.py +4 -4
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_result.py +7 -18
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/agent.py +24 -21
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/__init__.py +40 -36
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/anthropic.py +3 -1
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/gemini.py +52 -14
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/instrumented.py +25 -27
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/openai.py +56 -15
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/vertexai.py +9 -1
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/__init__.py +64 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/deepseek.py +68 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/google_gla.py +44 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/google_vertex.py +200 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/openai.py +72 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/result.py +19 -27
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pyproject.toml +4 -4
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/README.md +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.31 → pydantic_ai_slim-0.0.33}/pydantic_ai/usage.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.31
+Version: 0.0.33
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,12 +28,12 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
-Requires-Dist: logfire-api>=1.2.0
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.31
+Requires-Dist: pydantic-graph==0.0.33
 Requires-Dist: pydantic>=2.10
+Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
-Requires-Dist: anthropic>=0.
+Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
 Provides-Extra: cohere
 Requires-Dist: cohere>=5.13.11; extra == 'cohere'
 Provides-Extra: duckduckgo
pydantic_ai/_agent_graph.py
@@ -2,13 +2,14 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
+import json
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast
 
-import logfire_api
+from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -42,17 +43,6 @@ __all__ = (
     'capture_run_messages',
 )
 
-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
 S = TypeVar('S')
@@ -105,7 +95,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
 
-    run_span: logfire_api.LogfireSpan
+    run_span: Span
+    tracer: Tracer
 
 
 class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
@@ -330,7 +321,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.state.run_step += 1
 
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with
+        with ctx.deps.tracer.start_as_current_span(
+            'preparing model request params', attributes=dict(run_step=ctx.state.run_step)
+        ):
             model_request_parameters = await _prepare_request_parameters(ctx)
         return model_settings, model_request_parameters
 
@@ -380,26 +373,12 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
-        with _logfire.span('handle model response') as handle_span:
-            stream = self._run_stream(ctx)
-            yield stream
+        stream = self._run_stream(ctx)
+        yield stream
 
-            # Run the stream to completion if it was not finished:
-            async for _event in stream:
-                pass
-
-            # Set the next node based on the final state of the stream
-            next_node = self._next_node
-            if isinstance(next_node, End):
-                handle_span.set_attribute('result', next_node.data)
-                handle_span.message = 'handle model response -> final result'
-            elif tool_responses := self._tool_responses:
-                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
-                # I'm thinking it might be better to just create a span for the handling of each tool
-                # than to set an attribute here.
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
+        # Run the stream to completion if it was not finished:
+        async for _event in stream:
+            pass
 
     async def _run_stream(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -494,10 +473,29 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         if tool_responses:
             messages.append(_messages.ModelRequest(parts=tool_responses))
 
-        run_span.
-
-
-
+        run_span.set_attributes(
+            {
+                **usage.opentelemetry_attributes(),
+                'all_messages_events': json.dumps(
+                    [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
+                ),
+                'final_result': final_result.data
+                if isinstance(final_result.data, str)
+                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
+            }
+        )
+        run_span.set_attributes(
+            {
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'all_messages_events': {'type': 'array'},
+                            'final_result': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
         )
 
         # End the run with self.data
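At the end of a run the final attributes are attached with `Span.set_attributes`, JSON-encoding anything structured, because OpenTelemetry attribute values must be primitives (or sequences of primitives); the separate `logfire.json_schema` attribute tells Logfire which attributes should be re-parsed as JSON. A minimal sketch of the pattern, with an illustrative payload:

    import json
    from opentelemetry.trace import NoOpTracer

    # OTel attributes only accept str/bool/int/float and sequences thereof,
    # so structured payloads are serialized to JSON strings before attaching.
    span = NoOpTracer().start_span('agent run')
    span.set_attributes({'all_messages_events': json.dumps([{'role': 'user', 'content': 'hi'}])})
    span.set_attributes({'logfire.json_schema': json.dumps({'type': 'object', 'properties': {'all_messages_events': {'type': 'array'}}})})
    span.end()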
@@ -619,7 +617,10 @@ async def process_function_tools(
 
     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
-
+    tool_names = [call.tool_name for _, call in calls_to_run]
+    with ctx.deps.tracer.start_as_current_span(
+        'running tools', attributes={'tools': tool_names, 'logfire.msg': f'running tools: {", ".join(tool_names)}'}
+    ):
         # TODO: Should we wrap each individual tool call in a dedicated span?
         tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
         pending = tasks
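All tool calls for a step are executed concurrently inside the single 'running tools' span. A rough sketch of that fan-out, where the tool coroutine is a stand-in for `Tool.run`:

    import asyncio

    async def main() -> None:
        async def run_tool(name: str) -> str:  # stand-in for Tool.run(call, run_context)
            await asyncio.sleep(0)
            return f'{name}: ok'

        # one named task per tool call, awaited in parallel under one span
        tasks = [asyncio.create_task(run_tool(n), name=n) for n in ('get_weather', 'search')]
        print(await asyncio.gather(*tasks))

    asyncio.run(main())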
pydantic_ai/_pydantic.py
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
 from __future__ import annotations as _annotations
 
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
+from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
 
 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
+from typing_extensions import get_origin
 
 from ._griffe import doc_descriptions
 from ._utils import check_object_json_schema, is_model_like
@@ -223,8 +224,7 @@ def _build_schema(
 
 
 def _is_call_ctx(annotation: Any) -> bool:
+    """Return whether the annotation is the `RunContext` class, parameterized or not."""
     from .tools import RunContext
 
-    return annotation is RunContext or (
-        _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
-    )
+    return annotation is RunContext or get_origin(annotation) is RunContext
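The simplification relies on `typing_extensions.get_origin` returning the bare class for a parameterized generic, which makes the old `_typing_extra.is_generic_alias` guard redundant. Roughly:

    from typing_extensions import get_origin
    from pydantic_ai.tools import RunContext

    assert get_origin(RunContext[int]) is RunContext  # parameterized alias
    assert get_origin(RunContext) is None             # bare class, caught by the `is` check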
pydantic_ai/_result.py
@@ -1,14 +1,14 @@
 from __future__ import annotations as _annotations
 
 import inspect
-import sys
-import types
 from collections.abc import Awaitable, Iterable, Iterator
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Literal, Union, cast
+from typing import Any, Callable, Generic, Literal, Union, cast
 
 from pydantic import TypeAdapter, ValidationError
-from typing_extensions import
+from typing_extensions import TypedDict, TypeVar, get_args, get_origin
+from typing_inspection import typing_objects
+from typing_inspection.introspection import is_union_origin
 
 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
@@ -248,23 +248,12 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
 
 
 def get_union_args(tp: Any) -> tuple[Any, ...]:
-    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty
-    if
+    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
+    if typing_objects.is_typealiastype(tp):
         tp = tp.__value__
 
     origin = get_origin(tp)
-    if origin_is_union(origin):
+    if is_union_origin(origin):
         return get_args(tp)
     else:
         return ()
-
-
-if sys.version_info < (3, 10):
-
-    def origin_is_union(tp: type[Any] | None) -> bool:
-        return tp is Union
-
-else:
-
-    def origin_is_union(tp: type[Any] | None) -> bool:
-        return tp is Union or tp is types.UnionType
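The new `typing-inspection` dependency replaces the hand-rolled, version-gated `origin_is_union` shim: `is_union_origin` recognizes both `typing.Union` and PEP 604 `X | Y` unions on every supported Python version. A quick sketch:

    from typing import Optional, Union

    from typing_extensions import get_origin
    from typing_inspection.introspection import is_union_origin

    assert is_union_origin(get_origin(Union[int, str]))
    assert is_union_origin(get_origin(Optional[int]))  # Optional is a Union too
    assert not is_union_origin(get_origin(list[int]))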
pydantic_ai/agent.py
@@ -8,7 +8,7 @@ from copy import deepcopy
 from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload
 
-import logfire_api
+from opentelemetry.trace import NoOpTracer, use_span
 from typing_extensions import TypeGuard, TypeVar, deprecated
 
 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -58,17 +58,6 @@ __all__ = (
     'UserPromptNode',
 )
 
-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
 S = TypeVar('S')
@@ -123,6 +112,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     The type of the result data, used to validate the result data, defaults to `str`.
     """
 
+    instrument: bool
+    """Automatically instrument with OpenTelemetry. Will use Logfire if it's configured."""
+
     _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _result_tool_name: str = dataclasses.field(repr=False)
     _result_tool_description: str | None = dataclasses.field(repr=False)
@@ -155,6 +147,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
+        instrument: bool = False,
     ):
         """Create an agent.
 
@@ -184,6 +177,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 [override the model][pydantic_ai.Agent.override] for testing.
             end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                 See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
+            instrument: Automatically instrument with OpenTelemetry. Will use Logfire if it's configured.
         """
         if model is None or defer_model_check:
             self.model = model
@@ -194,6 +188,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         self.name = name
         self.model_settings = model_settings
         self.result_type = result_type
+        self.instrument = instrument
 
         self._deps_type = deps_type
 
@@ -396,6 +391,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
         model_used = self._get_model(model)
+        del model
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -425,14 +421,20 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         model_settings = merge_model_settings(self.model_settings, model_settings)
         usage_limits = usage_limits or _usage.UsageLimits()
 
-
-
-
-
-
-
-
+        if isinstance(model_used, InstrumentedModel):
+            tracer = model_used.tracer
+        else:
+            tracer = NoOpTracer()
+        agent_name = self.name or 'agent'
+        run_span = tracer.start_span(
+            'agent run',
+            attributes={
+                'model_name': model_used.model_name if model_used else 'no-model',
+                'agent_name': agent_name,
+                'logfire.msg': f'{agent_name} run',
+            },
         )
+
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
             user_deps=deps,
             prompt=user_prompt,
@@ -447,6 +449,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             result_validators=result_validators,
             function_tools=self._function_tools,
             run_span=run_span,
+            tracer=tracer,
         )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
@@ -460,7 +463,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             state=state,
             deps=graph_deps,
             infer_name=False,
-            span=run_span,
+            span=use_span(run_span, end_on_exit=True),
         ) as graph_run:
             yield AgentRun(graph_run)
 
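Because the run span is now started manually with `tracer.start_span`, it is handed to the graph run wrapped in `use_span`, which makes it the current span and ends it when the context exits. The OpenTelemetry primitive in isolation:

    from opentelemetry.trace import NoOpTracer, use_span

    tracer = NoOpTracer()
    span = tracer.start_span('agent run')
    with use_span(span, end_on_exit=True):
        pass  # the span is current here; it is ended when the block exits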
@@ -1116,7 +1119,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-        if not isinstance(model_, InstrumentedModel):
+        if self.instrument and not isinstance(model_, InstrumentedModel):
             model_ = InstrumentedModel(model_)
 
         return model_
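With instrumentation now opt-in, `_get_model` only wraps the model in `InstrumentedModel` when the agent was created with `instrument=True` (or the model is already instrumented). A sketch of the new flag in use, with an illustrative model name:

    from pydantic_ai import Agent

    # instrument=False is the default, so no spans are emitted unless requested;
    # actually running this requires the OpenAI extra and OPENAI_API_KEY.
    agent = Agent('openai:gpt-4o', instrument=True)
    result = agent.run_sync('What is the capital of France?')
    print(result.data)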
pydantic_ai/models/__init__.py
@@ -28,9 +28,11 @@ if TYPE_CHECKING:
 
 
 KnownModelName = Literal[
+    'anthropic:claude-3-7-sonnet-latest',
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
+    'claude-3-7-sonnet-latest',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
@@ -47,6 +49,8 @@ KnownModelName = Literal[
     'cohere:command-r-plus-04-2024',
     'cohere:command-r-plus-08-2024',
     'cohere:command-r7b-12-2024',
+    'deepseek:deepseek-chat',
+    'deepseek:deepseek-reasoner',
     'google-gla:gemini-1.0-pro',
     'google-gla:gemini-1.5-flash',
     'google-gla:gemini-1.5-flash-8b',
@@ -56,6 +60,7 @@ KnownModelName = Literal[
     'google-gla:gemini-exp-1206',
     'google-gla:gemini-2.0-flash',
     'google-gla:gemini-2.0-flash-lite-preview-02-05',
+    'google-gla:gemini-2.0-pro-exp-02-05',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -65,6 +70,7 @@ KnownModelName = Literal[
     'google-vertex:gemini-exp-1206',
     'google-vertex:gemini-2.0-flash',
     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
+    'google-vertex:gemini-2.0-pro-exp-02-05',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
@@ -316,54 +322,52 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .test import TestModel
 
         return TestModel()
-    elif model.startswith('cohere:'):
-        from .cohere import CohereModel
 
-
-
-
+    try:
+        provider, model_name = model.split(':')
+    except ValueError:
+        model_name = model
+        # TODO(Marcelo): We should deprecate this way.
+        if model_name.startswith(('gpt', 'o1', 'o3')):
+            provider = 'openai'
+        elif model_name.startswith('claude'):
+            provider = 'anthropic'
+        elif model_name.startswith('gemini'):
+            provider = 'google-gla'
+        else:
+            raise UserError(f'Unknown model: {model}')
+
+    if provider == 'vertexai':
+        provider = 'google-vertex'
+
+    if provider == 'cohere':
+        from .cohere import CohereModel
 
-
-
+        # TODO(Marcelo): Missing provider API.
+        return CohereModel(model_name)
+    elif provider in ('deepseek', 'openai'):
         from .openai import OpenAIModel
 
-        return OpenAIModel(
-    elif model.startswith('google-gla:'):
-        from .gemini import GeminiModel
-
-        return GeminiModel(model[11:])
-    # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
-    elif model.startswith('gemini'):
+        return OpenAIModel(model_name, provider=provider)
+    elif provider in ('google-gla', 'google-vertex'):
         from .gemini import GeminiModel
 
-
-
-    elif model.startswith('groq:'):
+        return GeminiModel(model_name, provider=provider)
+    elif provider == 'groq':
         from .groq import GroqModel
 
-
-
-
-
-        return VertexAIModel(model[14:])
-    # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
-    elif model.startswith('vertexai:'):
-        from .vertexai import VertexAIModel
-
-        return VertexAIModel(model[9:])
-    elif model.startswith('mistral:'):
+        # TODO(Marcelo): Missing provider API.
+        return GroqModel(model_name)
+    elif provider == 'mistral':
         from .mistral import MistralModel
 
-
-
-
-
-        return AnthropicModel(model[10:])
-    # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
-    elif model.startswith('claude'):
+        # TODO(Marcelo): Missing provider API.
+        return MistralModel(model_name)
+    elif provider == 'anthropic':
         from .anthropic import AnthropicModel
 
-
+        # TODO(Marcelo): Missing provider API.
+        return AnthropicModel(model_name)
     else:
         raise UserError(f'Unknown model: {model}')
 
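After this rewrite, `infer_model` splits a `provider:model` string once and dispatches on the provider prefix; bare names such as `gpt-...`, `claude-...`, and `gemini-...` are still accepted for backwards compatibility, and the old `vertexai:` prefix maps to `google-vertex`. A sketch of the accepted strings (assuming the relevant optional dependencies and provider API-key environment variables are set):

    from pydantic_ai.models import infer_model

    m1 = infer_model('deepseek:deepseek-chat')          # OpenAIModel with the DeepSeek provider
    m2 = infer_model('google-vertex:gemini-2.0-flash')  # GeminiModel(provider='google-vertex')
    m3 = infer_model('claude-3-5-sonnet-latest')        # bare name -> anthropic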
pydantic_ai/models/anthropic.py
@@ -42,6 +42,7 @@ from . import (
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ContentBlock,
         ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
@@ -69,6 +70,7 @@ except ImportError as _import_error:
     ) from _import_error
 
 LatestAnthropicModelNames = Literal[
+    'claude-3-7-sonnet-latest',
     'claude-3-5-haiku-latest',
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
@@ -423,7 +425,7 @@ class AnthropicStreamedResponse(StreamedResponse):
     _timestamp: datetime
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        current_block:
+        current_block: ContentBlock | None = None
         current_json: str = ''
 
         async for event in self._response:
pydantic_ai/models/gemini.py
@@ -8,12 +8,14 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union, cast
+from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
 from uuid import uuid4
 
 import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
+
+from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
@@ -53,6 +55,7 @@ LatestGeminiModelNames = Literal[
     'gemini-exp-1206',
     'gemini-2.0-flash',
     'gemini-2.0-flash-lite-preview-02-05',
+    'gemini-2.0-pro-exp-02-05',
 ]
 """Latest Gemini models."""
 
@@ -81,17 +84,39 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-
+    client: AsyncHTTPClient = field(repr=False)
 
     _model_name: GeminiModelName = field(repr=False)
+    _provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str | None = field(default='google-gla', repr=False)
 
+    @overload
+    def __init__(
+        self,
+        model_name: GeminiModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
+    ) -> None: ...
+
+    @deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
+    @overload
     def __init__(
         self,
         model_name: GeminiModelName,
         *,
+        provider: None = None,
+        api_key: str | None = None,
+        http_client: AsyncHTTPClient | None = None,
+        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
+    ) -> None: ...
+
+    def __init__(
+        self,
+        model_name: GeminiModelName,
+        *,
+        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
         api_key: str | None = None,
         http_client: AsyncHTTPClient | None = None,
         url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
@@ -100,6 +125,7 @@ class GeminiModel(Model):
 
         Args:
             model_name: The name of the model to use.
+            provider: The provider to use for the model.
             api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                 will be used if available.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
@@ -108,14 +134,24 @@ class GeminiModel(Model):
             `model` is substituted with the model name, and `function` is added to the end of the URL.
         """
         self._model_name = model_name
-        if api_key is None:
-            if env_api_key := os.getenv('GEMINI_API_KEY'):
-                api_key = env_api_key
+        self._provider = provider
+
+        if provider is not None:
+            if isinstance(provider, str):
+                self._system = provider
+                self.client = infer_provider(provider).client
             else:
-                raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-
-        self._auth = ApiKeyAuth(api_key)
-        self._url = url_template.format(model=model_name)
+                self._system = provider.name
+                self.client = provider.client
+        else:
+            if api_key is None:
+                if env_api_key := os.getenv('GEMINI_API_KEY'):
+                    api_key = env_api_key
+                else:
+                    raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
+            self.client = http_client or cached_async_http_client()
+            self._auth = ApiKeyAuth(api_key)
+            self._url = url_template.format(model=model_name)
 
     @property
     def auth(self) -> AuthProtocol:
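The net effect of these `__init__` changes: a `provider` argument (a string or a `Provider` instance) now supplies the authenticated `httpx` client, while the old `api_key`/`http_client`/`url_template` path is kept but deprecated. A sketch of both styles (API keys assumed to come from environment variables):

    from pydantic_ai.models.gemini import GeminiModel

    # new style: the provider string selects client, auth and base URL
    model = GeminiModel('gemini-2.0-flash', provider='google-gla')

    # deprecated style, still accepted for now (falls back to GEMINI_API_KEY)
    legacy = GeminiModel('gemini-2.0-flash', api_key='your-api-key')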
@@ -216,17 +252,19 @@ class GeminiModel(Model):
         if generation_config:
             request_data['generation_config'] = generation_config
 
-        url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
-
         headers = {
             'Content-Type': 'application/json',
             'User-Agent': get_user_agent(),
-            **await self.auth.headers(),
         }
+        if self._provider is None:  # pragma: no cover
+            url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
+            headers.update(await self.auth.headers())
+        else:
+            url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.http_client.stream(
+        async with self.client.stream(
             'POST',
             url,
             content=request_json,