pydantic-ai-slim 0.0.43__tar.gz → 0.0.45__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_cli.py +1 -1
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_griffe.py +29 -2
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_parts_manager.py +7 -1
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_utils.py +12 -6
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/agent.py +2 -2
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/exceptions.py +2 -2
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/messages.py +15 -27
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/__init__.py +15 -14
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/anthropic.py +7 -46
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/bedrock.py +7 -11
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/cohere.py +14 -20
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/gemini.py +18 -73
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/groq.py +9 -53
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/instrumented.py +14 -3
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/mistral.py +12 -51
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/openai.py +17 -75
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/__init__.py +4 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/anthropic.py +4 -5
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/azure.py +8 -9
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/bedrock.py +2 -1
- pydantic_ai_slim-0.0.45/pydantic_ai/providers/cohere.py +71 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/deepseek.py +4 -4
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/google_gla.py +3 -2
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/google_vertex.py +2 -3
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/groq.py +4 -5
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/mistral.py +4 -5
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/openai.py +5 -8
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pyproject.toml +3 -3
- pydantic_ai_slim-0.0.43/pydantic_ai/models/vertexai.py +0 -260
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/README.md +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.43 → pydantic_ai_slim-0.0.45}/pydantic_ai/usage.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.43
|
|
3
|
+
Version: 0.0.45
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
|
|
|
29
29
|
Requires-Dist: griffe>=1.3.2
|
|
30
30
|
Requires-Dist: httpx>=0.27
|
|
31
31
|
Requires-Dist: opentelemetry-api>=1.28.0
|
|
32
|
-
Requires-Dist: pydantic-graph==0.0.43
|
|
32
|
+
Requires-Dist: pydantic-graph==0.0.45
|
|
33
33
|
Requires-Dist: pydantic>=2.10
|
|
34
34
|
Requires-Dist: typing-inspection>=0.4.0
|
|
35
35
|
Provides-Extra: anthropic
|
|
@@ -53,7 +53,7 @@ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
|
|
|
53
53
|
Provides-Extra: mistral
|
|
54
54
|
Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
|
|
55
55
|
Provides-Extra: openai
|
|
56
|
-
Requires-Dist: openai>=1.
|
|
56
|
+
Requires-Dist: openai>=1.67.0; extra == 'openai'
|
|
57
57
|
Provides-Extra: tavily
|
|
58
58
|
Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
|
|
59
59
|
Provides-Extra: vertexai
|
|
@@ -37,7 +37,7 @@ except ImportError as _import_error:
|
|
|
37
37
|
from pydantic_ai.agent import Agent
|
|
38
38
|
from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
|
|
39
39
|
|
|
40
|
-
__version__ = version('pydantic-ai')
|
|
40
|
+
__version__ = version('pydantic-ai-slim')
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class SimpleCodeBlock(CodeBlock):
|
|
@@ -22,8 +22,16 @@ def doc_descriptions(
|
|
|
22
22
|
) -> tuple[str, dict[str, str]]:
|
|
23
23
|
"""Extract the function description and parameter descriptions from a function's docstring.
|
|
24
24
|
|
|
25
|
+
The function parses the docstring using the specified format (or infers it if 'auto')
|
|
26
|
+
and extracts both the main description and parameter descriptions. If a returns section
|
|
27
|
+
is present in the docstring, the main description will be formatted as XML.
|
|
28
|
+
|
|
25
29
|
Returns:
|
|
26
|
-
A tuple
|
|
30
|
+
A tuple containing:
|
|
31
|
+
- str: Main description string, which may be either:
|
|
32
|
+
* Plain text if no returns section is present
|
|
33
|
+
* XML-formatted if returns section exists, including <summary> and <returns> tags
|
|
34
|
+
- dict[str, str]: Dictionary mapping parameter names to their descriptions
|
|
27
35
|
"""
|
|
28
36
|
doc = func.__doc__
|
|
29
37
|
if doc is None:
|
|
@@ -33,7 +41,14 @@ def doc_descriptions(
|
|
|
33
41
|
parent = cast(GriffeObject, sig)
|
|
34
42
|
|
|
35
43
|
docstring_style = _infer_docstring_style(doc) if docstring_format == 'auto' else docstring_format
|
|
36
|
-
docstring = Docstring(
|
|
44
|
+
docstring = Docstring(
|
|
45
|
+
doc,
|
|
46
|
+
lineno=1,
|
|
47
|
+
parser=docstring_style,
|
|
48
|
+
parent=parent,
|
|
49
|
+
# https://mkdocstrings.github.io/griffe/reference/docstrings/#google-options
|
|
50
|
+
parser_options={'returns_named_value': False, 'returns_multiple_items': False},
|
|
51
|
+
)
|
|
37
52
|
with _disable_griffe_logging():
|
|
38
53
|
sections = docstring.parse()
|
|
39
54
|
|
|
@@ -45,6 +60,18 @@ def doc_descriptions(
|
|
|
45
60
|
if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None):
|
|
46
61
|
main_desc = main.value
|
|
47
62
|
|
|
63
|
+
if return_ := next((p for p in sections if p.kind == DocstringSectionKind.returns), None):
|
|
64
|
+
return_statement = return_.value[0]
|
|
65
|
+
return_desc = return_statement.description
|
|
66
|
+
return_type = return_statement.annotation
|
|
67
|
+
type_tag = f'<type>{return_type}</type>\n' if return_type else ''
|
|
68
|
+
return_xml = f'<returns>\n{type_tag}<description>{return_desc}</description>\n</returns>'
|
|
69
|
+
|
|
70
|
+
if main_desc:
|
|
71
|
+
main_desc = f'<summary>{main_desc}</summary>\n{return_xml}'
|
|
72
|
+
else:
|
|
73
|
+
main_desc = return_xml
|
|
74
|
+
|
|
48
75
|
return main_desc, params
|
|
49
76
|
|
|
50
77
|
|
|
@@ -29,6 +29,8 @@ from pydantic_ai.messages import (
|
|
|
29
29
|
ToolCallPartDelta,
|
|
30
30
|
)
|
|
31
31
|
|
|
32
|
+
from ._utils import generate_tool_call_id as _generate_tool_call_id
|
|
33
|
+
|
|
32
34
|
VendorId = Hashable
|
|
33
35
|
"""
|
|
34
36
|
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
|
|
@@ -221,7 +223,11 @@ class ModelResponsePartsManager:
|
|
|
221
223
|
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
|
|
222
224
|
has been added to the manager, or replaced an existing part.
|
|
223
225
|
"""
|
|
224
|
-
new_part = ToolCallPart(
|
|
226
|
+
new_part = ToolCallPart(
|
|
227
|
+
tool_name=tool_name,
|
|
228
|
+
args=args,
|
|
229
|
+
tool_call_id=tool_call_id or _generate_tool_call_id(),
|
|
230
|
+
)
|
|
225
231
|
if vendor_part_id is None:
|
|
226
232
|
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
|
|
227
233
|
new_part_index = len(self._parts)
|
|
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import time
|
|
5
|
+
import uuid
|
|
5
6
|
from collections.abc import AsyncIterable, AsyncIterator, Iterator
|
|
6
7
|
from contextlib import asynccontextmanager, suppress
|
|
7
8
|
from dataclasses import dataclass, is_dataclass
|
|
@@ -195,12 +196,17 @@ def now_utc() -> datetime:
|
|
|
195
196
|
return datetime.now(tz=timezone.utc)
|
|
196
197
|
|
|
197
198
|
|
|
198
|
-
def guard_tool_call_id(
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
199
|
+
def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
|
|
200
|
+
"""Type guard that either returns the tool call id or generates a new one if it's None."""
|
|
201
|
+
return t.tool_call_id or generate_tool_call_id()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def generate_tool_call_id() -> str:
|
|
205
|
+
"""Generate a tool call id.
|
|
206
|
+
|
|
207
|
+
Ensure that the tool call id is unique.
|
|
208
|
+
"""
|
|
209
|
+
return f'pyd_ai_{uuid.uuid4().hex}'
|
|
204
210
|
|
|
205
211
|
|
|
206
212
|
class PeekableAsyncStream(Generic[T]):
|
|
@@ -13,7 +13,7 @@ from pydantic.json_schema import GenerateJsonSchema
|
|
|
13
13
|
from typing_extensions import TypeGuard, TypeVar, deprecated
|
|
14
14
|
|
|
15
15
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
16
|
-
from pydantic_graph._utils import
|
|
16
|
+
from pydantic_graph._utils import run_until_complete
|
|
17
17
|
|
|
18
18
|
from . import (
|
|
19
19
|
_agent_graph,
|
|
@@ -567,7 +567,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
567
567
|
"""
|
|
568
568
|
if infer_name and self.name is None:
|
|
569
569
|
self._infer_name(inspect.currentframe())
|
|
570
|
-
return
|
|
570
|
+
return run_until_complete(
|
|
571
571
|
self.run(
|
|
572
572
|
user_prompt,
|
|
573
573
|
result_type=result_type,
|
|
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import json
|
|
4
4
|
import sys
|
|
5
5
|
|
|
6
|
-
if sys.version_info < (3, 11):
|
|
6
|
+
if sys.version_info < (3, 11): # pragma: no cover
|
|
7
7
|
from exceptiongroup import ExceptionGroup
|
|
8
|
-
else:
|
|
8
|
+
else: # pragma: no cover
|
|
9
9
|
ExceptionGroup = ExceptionGroup
|
|
10
10
|
|
|
11
11
|
__all__ = (
|
|
@@ -12,7 +12,7 @@ import pydantic_core
|
|
|
12
12
|
from opentelemetry._events import Event
|
|
13
13
|
from typing_extensions import TypeAlias
|
|
14
14
|
|
|
15
|
-
from ._utils import now_utc as _now_utc
|
|
15
|
+
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
|
|
16
16
|
from .exceptions import UnexpectedModelBehavior
|
|
17
17
|
|
|
18
18
|
|
|
@@ -268,8 +268,8 @@ class ToolReturnPart:
|
|
|
268
268
|
content: Any
|
|
269
269
|
"""The return value."""
|
|
270
270
|
|
|
271
|
-
tool_call_id: str
|
|
272
|
-
"""
|
|
271
|
+
tool_call_id: str
|
|
272
|
+
"""The tool call identifier, this is used by some models including OpenAI."""
|
|
273
273
|
|
|
274
274
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
275
275
|
"""The timestamp, when the tool returned."""
|
|
@@ -328,8 +328,11 @@ class RetryPromptPart:
|
|
|
328
328
|
tool_name: str | None = None
|
|
329
329
|
"""The name of the tool that was called, if any."""
|
|
330
330
|
|
|
331
|
-
tool_call_id: str
|
|
332
|
-
"""
|
|
331
|
+
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
332
|
+
"""The tool call identifier, this is used by some models including OpenAI.
|
|
333
|
+
|
|
334
|
+
In case the tool call id is not provided by the model, PydanticAI will generate a random one.
|
|
335
|
+
"""
|
|
333
336
|
|
|
334
337
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
335
338
|
"""The timestamp, when the retry was triggered."""
|
|
@@ -406,8 +409,11 @@ class ToolCallPart:
|
|
|
406
409
|
This is stored either as a JSON string or a Python dictionary depending on how data was received.
|
|
407
410
|
"""
|
|
408
411
|
|
|
409
|
-
tool_call_id: str
|
|
410
|
-
"""
|
|
412
|
+
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
413
|
+
"""The tool call identifier, this is used by some models including OpenAI.
|
|
414
|
+
|
|
415
|
+
In case the tool call id is not provided by the model, PydanticAI will generate a random one.
|
|
416
|
+
"""
|
|
411
417
|
|
|
412
418
|
part_kind: Literal['tool-call'] = 'tool-call'
|
|
413
419
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
@@ -564,11 +570,7 @@ class ToolCallPartDelta:
|
|
|
564
570
|
if self.tool_name_delta is None or self.args_delta is None:
|
|
565
571
|
return None
|
|
566
572
|
|
|
567
|
-
return ToolCallPart(
|
|
568
|
-
self.tool_name_delta,
|
|
569
|
-
self.args_delta,
|
|
570
|
-
self.tool_call_id,
|
|
571
|
-
)
|
|
573
|
+
return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
|
|
572
574
|
|
|
573
575
|
@overload
|
|
574
576
|
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
|
|
@@ -620,20 +622,11 @@ class ToolCallPartDelta:
|
|
|
620
622
|
delta = replace(delta, args_delta=updated_args_delta)
|
|
621
623
|
|
|
622
624
|
if self.tool_call_id:
|
|
623
|
-
# Set the tool_call_id if it wasn't present, otherwise error if it has changed
|
|
624
|
-
if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
|
|
625
|
-
raise UnexpectedModelBehavior(
|
|
626
|
-
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
|
|
627
|
-
)
|
|
628
625
|
delta = replace(delta, tool_call_id=self.tool_call_id)
|
|
629
626
|
|
|
630
627
|
# If we now have enough data to create a full ToolCallPart, do so
|
|
631
628
|
if delta.tool_name_delta is not None and delta.args_delta is not None:
|
|
632
|
-
return ToolCallPart(
|
|
633
|
-
delta.tool_name_delta,
|
|
634
|
-
delta.args_delta,
|
|
635
|
-
delta.tool_call_id,
|
|
636
|
-
)
|
|
629
|
+
return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())
|
|
637
630
|
|
|
638
631
|
return delta
|
|
639
632
|
|
|
@@ -656,11 +649,6 @@ class ToolCallPartDelta:
|
|
|
656
649
|
part = replace(part, args=updated_dict)
|
|
657
650
|
|
|
658
651
|
if self.tool_call_id:
|
|
659
|
-
# Replace the tool_call_id entirely if given
|
|
660
|
-
if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
|
|
661
|
-
raise UnexpectedModelBehavior(
|
|
662
|
-
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
|
|
663
|
-
)
|
|
664
652
|
part = replace(part, tool_call_id=self.tool_call_id)
|
|
665
653
|
return part
|
|
666
654
|
|
|
@@ -12,7 +12,7 @@ from contextlib import asynccontextmanager, contextmanager
|
|
|
12
12
|
from dataclasses import dataclass, field
|
|
13
13
|
from datetime import datetime
|
|
14
14
|
from functools import cache
|
|
15
|
-
from typing import TYPE_CHECKING
|
|
15
|
+
from typing import TYPE_CHECKING, cast
|
|
16
16
|
|
|
17
17
|
import httpx
|
|
18
18
|
from typing_extensions import Literal
|
|
@@ -133,8 +133,6 @@ KnownModelName = Literal[
|
|
|
133
133
|
'gpt-4-turbo-2024-04-09',
|
|
134
134
|
'gpt-4-turbo-preview',
|
|
135
135
|
'gpt-4-vision-preview',
|
|
136
|
-
'gpt-4.5-preview',
|
|
137
|
-
'gpt-4.5-preview-2025-02-27',
|
|
138
136
|
'gpt-4o',
|
|
139
137
|
'gpt-4o-2024-05-13',
|
|
140
138
|
'gpt-4o-2024-08-06',
|
|
@@ -146,6 +144,10 @@ KnownModelName = Literal[
|
|
|
146
144
|
'gpt-4o-mini-2024-07-18',
|
|
147
145
|
'gpt-4o-mini-audio-preview',
|
|
148
146
|
'gpt-4o-mini-audio-preview-2024-12-17',
|
|
147
|
+
'gpt-4o-mini-search-preview',
|
|
148
|
+
'gpt-4o-mini-search-preview-2025-03-11',
|
|
149
|
+
'gpt-4o-search-preview',
|
|
150
|
+
'gpt-4o-search-preview-2025-03-11',
|
|
149
151
|
'groq:gemma2-9b-it',
|
|
150
152
|
'groq:llama-3.1-8b-instant',
|
|
151
153
|
'groq:llama-3.2-11b-vision-preview',
|
|
@@ -189,8 +191,6 @@ KnownModelName = Literal[
|
|
|
189
191
|
'openai:gpt-4-turbo-2024-04-09',
|
|
190
192
|
'openai:gpt-4-turbo-preview',
|
|
191
193
|
'openai:gpt-4-vision-preview',
|
|
192
|
-
'openai:gpt-4.5-preview',
|
|
193
|
-
'openai:gpt-4.5-preview-2025-02-27',
|
|
194
194
|
'openai:gpt-4o',
|
|
195
195
|
'openai:gpt-4o-2024-05-13',
|
|
196
196
|
'openai:gpt-4o-2024-08-06',
|
|
@@ -202,6 +202,10 @@ KnownModelName = Literal[
|
|
|
202
202
|
'openai:gpt-4o-mini-2024-07-18',
|
|
203
203
|
'openai:gpt-4o-mini-audio-preview',
|
|
204
204
|
'openai:gpt-4o-mini-audio-preview-2024-12-17',
|
|
205
|
+
'openai:gpt-4o-mini-search-preview',
|
|
206
|
+
'openai:gpt-4o-mini-search-preview-2025-03-11',
|
|
207
|
+
'openai:gpt-4o-search-preview',
|
|
208
|
+
'openai:gpt-4o-search-preview-2025-03-11',
|
|
205
209
|
'openai:o1',
|
|
206
210
|
'openai:o1-2024-12-17',
|
|
207
211
|
'openai:o1-mini',
|
|
@@ -379,6 +383,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
379
383
|
|
|
380
384
|
try:
|
|
381
385
|
provider, model_name = model.split(':', maxsplit=1)
|
|
386
|
+
provider = cast(str, provider)
|
|
382
387
|
except ValueError:
|
|
383
388
|
model_name = model
|
|
384
389
|
# TODO(Marcelo): We should deprecate this way.
|
|
@@ -397,8 +402,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
397
402
|
if provider == 'cohere':
|
|
398
403
|
from .cohere import CohereModel
|
|
399
404
|
|
|
400
|
-
|
|
401
|
-
return CohereModel(model_name)
|
|
405
|
+
return CohereModel(model_name, provider=provider)
|
|
402
406
|
elif provider in ('deepseek', 'openai'):
|
|
403
407
|
from .openai import OpenAIModel
|
|
404
408
|
|
|
@@ -410,22 +414,19 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
410
414
|
elif provider == 'groq':
|
|
411
415
|
from .groq import GroqModel
|
|
412
416
|
|
|
413
|
-
|
|
414
|
-
return GroqModel(model_name)
|
|
417
|
+
return GroqModel(model_name, provider=provider)
|
|
415
418
|
elif provider == 'mistral':
|
|
416
419
|
from .mistral import MistralModel
|
|
417
420
|
|
|
418
|
-
|
|
419
|
-
return MistralModel(model_name)
|
|
421
|
+
return MistralModel(model_name, provider=provider)
|
|
420
422
|
elif provider == 'anthropic':
|
|
421
423
|
from .anthropic import AnthropicModel
|
|
422
424
|
|
|
423
|
-
|
|
424
|
-
return AnthropicModel(model_name)
|
|
425
|
+
return AnthropicModel(model_name, provider=provider)
|
|
425
426
|
elif provider == 'bedrock':
|
|
426
427
|
from .bedrock import BedrockConverseModel
|
|
427
428
|
|
|
428
|
-
return BedrockConverseModel(model_name)
|
|
429
|
+
return BedrockConverseModel(model_name, provider=provider)
|
|
429
430
|
else:
|
|
430
431
|
raise UserError(f'Unknown model: {model}')
|
|
431
432
|
|
|
@@ -10,8 +10,7 @@ from json import JSONDecodeError, loads as json_loads
|
|
|
10
10
|
from typing import Any, Literal, Union, cast, overload
|
|
11
11
|
|
|
12
12
|
from anthropic.types import DocumentBlockParam
|
|
13
|
-
from
|
|
14
|
-
from typing_extensions import assert_never, deprecated
|
|
13
|
+
from typing_extensions import assert_never
|
|
15
14
|
|
|
16
15
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
17
16
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
@@ -112,34 +111,11 @@ class AnthropicModel(Model):
|
|
|
112
111
|
_model_name: AnthropicModelName = field(repr=False)
|
|
113
112
|
_system: str = field(default='anthropic', repr=False)
|
|
114
113
|
|
|
115
|
-
@overload
|
|
116
114
|
def __init__(
|
|
117
115
|
self,
|
|
118
116
|
model_name: AnthropicModelName,
|
|
119
117
|
*,
|
|
120
118
|
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
|
|
121
|
-
) -> None: ...
|
|
122
|
-
|
|
123
|
-
@deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
|
|
124
|
-
@overload
|
|
125
|
-
def __init__(
|
|
126
|
-
self,
|
|
127
|
-
model_name: AnthropicModelName,
|
|
128
|
-
*,
|
|
129
|
-
provider: None = None,
|
|
130
|
-
api_key: str | None = None,
|
|
131
|
-
anthropic_client: AsyncAnthropic | None = None,
|
|
132
|
-
http_client: AsyncHTTPClient | None = None,
|
|
133
|
-
) -> None: ...
|
|
134
|
-
|
|
135
|
-
def __init__(
|
|
136
|
-
self,
|
|
137
|
-
model_name: AnthropicModelName,
|
|
138
|
-
*,
|
|
139
|
-
provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
|
|
140
|
-
api_key: str | None = None,
|
|
141
|
-
anthropic_client: AsyncAnthropic | None = None,
|
|
142
|
-
http_client: AsyncHTTPClient | None = None,
|
|
143
119
|
):
|
|
144
120
|
"""Initialize an Anthropic model.
|
|
145
121
|
|
|
@@ -148,27 +124,12 @@ class AnthropicModel(Model):
|
|
|
148
124
|
[here](https://docs.anthropic.com/en/docs/about-claude/models).
|
|
149
125
|
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
|
|
150
126
|
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
|
|
151
|
-
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
|
|
152
|
-
will be used if available.
|
|
153
|
-
anthropic_client: An existing
|
|
154
|
-
[`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
|
|
155
|
-
client to use, if provided, `api_key` and `http_client` must be `None`.
|
|
156
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
157
127
|
"""
|
|
158
128
|
self._model_name = model_name
|
|
159
129
|
|
|
160
|
-
if provider
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
self.client = provider.client
|
|
164
|
-
elif anthropic_client is not None:
|
|
165
|
-
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
|
|
166
|
-
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
|
|
167
|
-
self.client = anthropic_client
|
|
168
|
-
elif http_client is not None:
|
|
169
|
-
self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
|
|
170
|
-
else:
|
|
171
|
-
self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
|
|
130
|
+
if isinstance(provider, str):
|
|
131
|
+
provider = infer_provider(provider)
|
|
132
|
+
self.client = provider.client
|
|
172
133
|
|
|
173
134
|
@property
|
|
174
135
|
def base_url(self) -> str:
|
|
@@ -326,7 +287,7 @@ class AnthropicModel(Model):
|
|
|
326
287
|
user_content_params.append(content)
|
|
327
288
|
elif isinstance(request_part, ToolReturnPart):
|
|
328
289
|
tool_result_block_param = ToolResultBlockParam(
|
|
329
|
-
tool_use_id=_guard_tool_call_id(t=request_part
|
|
290
|
+
tool_use_id=_guard_tool_call_id(t=request_part),
|
|
330
291
|
type='tool_result',
|
|
331
292
|
content=request_part.model_response_str(),
|
|
332
293
|
is_error=False,
|
|
@@ -337,7 +298,7 @@ class AnthropicModel(Model):
|
|
|
337
298
|
retry_param = TextBlockParam(type='text', text=request_part.model_response())
|
|
338
299
|
else:
|
|
339
300
|
retry_param = ToolResultBlockParam(
|
|
340
|
-
tool_use_id=_guard_tool_call_id(t=request_part
|
|
301
|
+
tool_use_id=_guard_tool_call_id(t=request_part),
|
|
341
302
|
type='tool_result',
|
|
342
303
|
content=request_part.model_response(),
|
|
343
304
|
is_error=True,
|
|
@@ -351,7 +312,7 @@ class AnthropicModel(Model):
|
|
|
351
312
|
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
|
|
352
313
|
else:
|
|
353
314
|
tool_use_block_param = ToolUseBlockParam(
|
|
354
|
-
id=_guard_tool_call_id(t=response_part
|
|
315
|
+
id=_guard_tool_call_id(t=response_part),
|
|
355
316
|
type='tool_use',
|
|
356
317
|
name=response_part.tool_name,
|
|
357
318
|
input=response_part.args_as_dict(),
|
|
@@ -143,14 +143,15 @@ class BedrockConverseModel(Model):
|
|
|
143
143
|
model_name: The name of the model to use.
|
|
144
144
|
model_name: The name of the Bedrock model to use. List of model names available
|
|
145
145
|
[here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
|
|
146
|
-
provider: The provider to use.
|
|
146
|
+
provider: The provider to use for authentication and API access. Can be either the string
|
|
147
|
+
'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
|
|
148
|
+
created using the other parameters.
|
|
147
149
|
"""
|
|
148
150
|
self._model_name = model_name
|
|
149
151
|
|
|
150
152
|
if isinstance(provider, str):
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
self.client = cast('BedrockRuntimeClient', provider.client)
|
|
153
|
+
provider = infer_provider(provider)
|
|
154
|
+
self.client = cast('BedrockRuntimeClient', provider.client)
|
|
154
155
|
|
|
155
156
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
|
|
156
157
|
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
@@ -345,7 +346,7 @@ class BedrockConverseModel(Model):
|
|
|
345
346
|
content.append({'text': item.content})
|
|
346
347
|
else:
|
|
347
348
|
assert isinstance(item, ToolCallPart)
|
|
348
|
-
content.append(self._map_tool_call(item))
|
|
349
|
+
content.append(self._map_tool_call(item))
|
|
349
350
|
bedrock_messages.append({'role': 'assistant', 'content': content})
|
|
350
351
|
else:
|
|
351
352
|
assert_never(m)
|
|
@@ -394,13 +395,8 @@ class BedrockConverseModel(Model):
|
|
|
394
395
|
|
|
395
396
|
@staticmethod
|
|
396
397
|
def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
|
|
397
|
-
assert t.tool_call_id is not None
|
|
398
398
|
return {
|
|
399
|
-
'toolUse': {
|
|
400
|
-
'toolUseId': t.tool_call_id,
|
|
401
|
-
'name': t.tool_name,
|
|
402
|
-
'input': t.args_as_dict(),
|
|
403
|
-
}
|
|
399
|
+
'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
|
|
404
400
|
}
|
|
405
401
|
|
|
406
402
|
|
|
@@ -6,11 +6,10 @@ from itertools import chain
|
|
|
6
6
|
from typing import Literal, Union, cast
|
|
7
7
|
|
|
8
8
|
from cohere import TextAssistantMessageContentItem
|
|
9
|
-
from httpx import AsyncClient as AsyncHTTPClient
|
|
10
9
|
from typing_extensions import assert_never
|
|
11
10
|
|
|
12
11
|
from .. import ModelHTTPError, result
|
|
13
|
-
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
12
|
+
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
|
|
14
13
|
from ..messages import (
|
|
15
14
|
ModelMessage,
|
|
16
15
|
ModelRequest,
|
|
@@ -23,6 +22,7 @@ from ..messages import (
|
|
|
23
22
|
ToolReturnPart,
|
|
24
23
|
UserPromptPart,
|
|
25
24
|
)
|
|
25
|
+
from ..providers import Provider, infer_provider
|
|
26
26
|
from ..settings import ModelSettings
|
|
27
27
|
from ..tools import ToolDefinition
|
|
28
28
|
from . import (
|
|
@@ -104,28 +104,22 @@ class CohereModel(Model):
|
|
|
104
104
|
self,
|
|
105
105
|
model_name: CohereModelName,
|
|
106
106
|
*,
|
|
107
|
-
|
|
108
|
-
cohere_client: AsyncClientV2 | None = None,
|
|
109
|
-
http_client: AsyncHTTPClient | None = None,
|
|
107
|
+
provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
|
|
110
108
|
):
|
|
111
109
|
"""Initialize an Cohere model.
|
|
112
110
|
|
|
113
111
|
Args:
|
|
114
112
|
model_name: The name of the Cohere model to use. List of model names
|
|
115
113
|
available [here](https://docs.cohere.com/docs/models#command).
|
|
116
|
-
|
|
117
|
-
`
|
|
118
|
-
|
|
119
|
-
`api_key` and `http_client` must be `None`.
|
|
120
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
114
|
+
provider: The provider to use for authentication and API access. Can be either the string
|
|
115
|
+
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
|
|
116
|
+
created using the other parameters.
|
|
121
117
|
"""
|
|
122
118
|
self._model_name: CohereModelName = model_name
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
else:
|
|
128
|
-
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
|
|
119
|
+
|
|
120
|
+
if isinstance(provider, str):
|
|
121
|
+
provider = infer_provider(provider)
|
|
122
|
+
self.client = provider.client
|
|
129
123
|
|
|
130
124
|
@property
|
|
131
125
|
def base_url(self) -> str:
|
|
@@ -191,7 +185,7 @@ class CohereModel(Model):
|
|
|
191
185
|
ToolCallPart(
|
|
192
186
|
tool_name=c.function.name,
|
|
193
187
|
args=c.function.arguments,
|
|
194
|
-
tool_call_id=c.id,
|
|
188
|
+
tool_call_id=c.id or _generate_tool_call_id(),
|
|
195
189
|
)
|
|
196
190
|
)
|
|
197
191
|
return ModelResponse(parts=parts, model_name=self._model_name)
|
|
@@ -228,7 +222,7 @@ class CohereModel(Model):
|
|
|
228
222
|
@staticmethod
|
|
229
223
|
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
|
|
230
224
|
return ToolCallV2(
|
|
231
|
-
id=_guard_tool_call_id(t=t
|
|
225
|
+
id=_guard_tool_call_id(t=t),
|
|
232
226
|
type='function',
|
|
233
227
|
function=ToolCallV2Function(
|
|
234
228
|
name=t.tool_name,
|
|
@@ -260,7 +254,7 @@ class CohereModel(Model):
|
|
|
260
254
|
elif isinstance(part, ToolReturnPart):
|
|
261
255
|
yield ToolChatMessageV2(
|
|
262
256
|
role='tool',
|
|
263
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
257
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
264
258
|
content=part.model_response_str(),
|
|
265
259
|
)
|
|
266
260
|
elif isinstance(part, RetryPromptPart):
|
|
@@ -269,7 +263,7 @@ class CohereModel(Model):
|
|
|
269
263
|
else:
|
|
270
264
|
yield ToolChatMessageV2(
|
|
271
265
|
role='tool',
|
|
272
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
266
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
273
267
|
content=part.model_response(),
|
|
274
268
|
)
|
|
275
269
|
else:
|