pydantic-ai-slim 0.0.44__tar.gz → 0.0.46__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/PKG-INFO +2 -2
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_parts_manager.py +7 -1
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_utils.py +12 -6
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/agent.py +2 -2
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/exceptions.py +2 -2
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/mcp.py +25 -1
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/messages.py +15 -27
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/__init__.py +15 -6
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/anthropic.py +7 -46
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/bedrock.py +7 -11
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/cohere.py +10 -50
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/gemini.py +18 -73
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/groq.py +9 -53
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/mistral.py +12 -51
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/openai.py +15 -67
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/anthropic.py +6 -6
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/azure.py +9 -10
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/bedrock.py +2 -1
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/cohere.py +6 -8
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/deepseek.py +6 -5
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/google_gla.py +4 -3
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/google_vertex.py +3 -4
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/groq.py +6 -8
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/mistral.py +6 -6
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/openai.py +6 -8
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pyproject.toml +2 -2
- pydantic_ai_slim-0.0.44/pydantic_ai/models/vertexai.py +0 -260
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/README.md +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/usage.py +0 -0
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.44
+Version: 0.0.46
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT

@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.44
+Requires-Dist: pydantic-graph==0.0.46
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_parts_manager.py

@@ -29,6 +29,8 @@ from pydantic_ai.messages import (
     ToolCallPartDelta,
 )
 
+from ._utils import generate_tool_call_id as _generate_tool_call_id
+
 VendorId = Hashable
 """
 Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)

@@ -221,7 +223,11 @@ class ModelResponsePartsManager:
            ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
                has been added to the manager, or replaced an existing part.
        """
-        new_part = ToolCallPart(
+        new_part = ToolCallPart(
+            tool_name=tool_name,
+            args=args,
+            tool_call_id=tool_call_id or _generate_tool_call_id(),
+        )
         if vendor_part_id is None:
             # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
             new_part_index = len(self._parts)
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/_utils.py

@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import time
+import uuid
 from collections.abc import AsyncIterable, AsyncIterator, Iterator
 from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass, is_dataclass

@@ -195,12 +196,17 @@ def now_utc() -> datetime:
     return datetime.now(tz=timezone.utc)
 
 
-def guard_tool_call_id(
+def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
+    """Type guard that either returns the tool call id or generates a new one if it's None."""
+    return t.tool_call_id or generate_tool_call_id()
+
+
+def generate_tool_call_id() -> str:
+    """Generate a tool call id.
+
+    Ensure that the tool call id is unique.
+    """
+    return f'pyd_ai_{uuid.uuid4().hex}'
 
 
 class PeekableAsyncStream(Generic[T]):
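The new `generate_tool_call_id` and the simplified `guard_tool_call_id` are small enough to illustrate directly. A minimal sketch of how they behave (these are internal helpers in `pydantic_ai._utils`; the tool name and args below are invented for the example):

    from pydantic_ai._utils import generate_tool_call_id, guard_tool_call_id
    from pydantic_ai.messages import ToolCallPart

    # Each call returns a fresh id with the 'pyd_ai_' prefix and a uuid4 hex body.
    print(generate_tool_call_id())  # e.g. 'pyd_ai_1a2b3c...'

    # guard_tool_call_id now falls back to generating an id instead of requiring one.
    part = ToolCallPart(tool_name='get_weather', args='{"city": "London"}')
    print(guard_tool_call_id(part))  # the part's (auto-generated) tool_call_id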
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/agent.py

@@ -13,7 +13,7 @@ from pydantic.json_schema import GenerateJsonSchema
 from typing_extensions import TypeGuard, TypeVar, deprecated
 
 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
-from pydantic_graph._utils import
+from pydantic_graph._utils import get_event_loop
 
 from . import (
     _agent_graph,

@@ -567,7 +567,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        return run_until_complete(
+        return get_event_loop().run_until_complete(
             self.run(
                 user_prompt,
                 result_type=result_type,
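This only changes how `run_sync` drives the event loop internally; the call itself is unchanged for users. A hedged usage sketch under the 0.0.x API (the model string and prompt are illustrative):

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o', system_prompt='Be concise.')
    result = agent.run_sync('What is the capital of France?')
    print(result.data)  # run_sync blocks until the underlying async run completes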
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/exceptions.py

@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
 import json
 import sys
 
-if sys.version_info < (3, 11):
+if sys.version_info < (3, 11):  # pragma: no cover
     from exceptiongroup import ExceptionGroup
-else:
+else:  # pragma: no cover
     ExceptionGroup = ExceptionGroup
 
 __all__ = (
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/mcp.py

@@ -188,11 +188,35 @@ class MCPServerHTTP(MCPServer):
    For example for a server running locally, this might be `http://localhost:3001/sse`.
    """
 
+    headers: dict[str, Any] | None = None
+    """Optional HTTP headers to be sent with each request to the SSE endpoint.
+
+    These headers will be passed directly to the underlying `httpx.AsyncClient`.
+    Useful for authentication, custom headers, or other HTTP-specific configurations.
+    """
+
+    timeout: float = 5
+    """Initial connection timeout in seconds for establishing the SSE connection.
+
+    This timeout applies to the initial connection setup and handshake.
+    If the connection cannot be established within this time, the operation will fail.
+    """
+
+    sse_read_timeout: float = 60 * 5
+    """Maximum time in seconds to wait for new SSE messages before timing out.
+
+    This timeout applies to the long-lived SSE connection after it's established.
+    If no new messages are received within this time, the connection will be considered stale
+    and may be closed. Defaults to 5 minutes (300 seconds).
+    """
+
     @asynccontextmanager
     async def client_streams(
         self,
     ) -> AsyncIterator[
         tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
     ]:  # pragma: no cover
-        async with sse_client(
+        async with sse_client(
+            url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout
+        ) as (read_stream, write_stream):
             yield read_stream, write_stream
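The three new fields map one-to-one onto the keyword arguments now passed to `sse_client`. A sketch of configuring them (the URL, header value, and timeouts are placeholders):

    from pydantic_ai.mcp import MCPServerHTTP

    server = MCPServerHTTP(
        url='http://localhost:3001/sse',              # SSE endpoint of the MCP server
        headers={'Authorization': 'Bearer <token>'},  # forwarded to the underlying httpx client
        timeout=10,                                   # initial connection/handshake timeout, seconds
        sse_read_timeout=120,                         # max idle time waiting for new SSE messages
    )

The server instance is then handed to an agent via its MCP-server list in the usual way.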
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/messages.py

@@ -12,7 +12,7 @@ import pydantic_core
 from opentelemetry._events import Event
 from typing_extensions import TypeAlias
 
-from ._utils import now_utc as _now_utc
+from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
 
 

@@ -268,8 +268,8 @@ class ToolReturnPart:
     content: Any
     """The return value."""
 
-    tool_call_id: str
-    """
+    tool_call_id: str
+    """The tool call identifier, this is used by some models including OpenAI."""
 
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the tool returned."""

@@ -328,8 +328,11 @@ class RetryPromptPart:
     tool_name: str | None = None
     """The name of the tool that was called, if any."""
 
-    tool_call_id: str
-    """
+    tool_call_id: str = field(default_factory=_generate_tool_call_id)
+    """The tool call identifier, this is used by some models including OpenAI.
+
+    In case the tool call id is not provided by the model, PydanticAI will generate a random one.
+    """
 
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the retry was triggered."""

@@ -406,8 +409,11 @@ class ToolCallPart:
     This is stored either as a JSON string or a Python dictionary depending on how data was received.
     """
 
-    tool_call_id: str
-    """
+    tool_call_id: str = field(default_factory=_generate_tool_call_id)
+    """The tool call identifier, this is used by some models including OpenAI.
+
+    In case the tool call id is not provided by the model, PydanticAI will generate a random one.
+    """
 
     part_kind: Literal['tool-call'] = 'tool-call'
     """Part type identifier, this is available on all parts as a discriminator."""
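With the default factories in place, message parts built without an explicit id remain valid. A small sketch (the tool name, args, and retry content are invented):

    from pydantic_ai.messages import RetryPromptPart, ToolCallPart

    call = ToolCallPart(tool_name='search', args={'query': 'pydantic'})
    print(call.tool_call_id)   # auto-generated, e.g. 'pyd_ai_...'

    retry = RetryPromptPart(content='Please call the tool again with a narrower query.')
    print(retry.tool_call_id)  # also auto-generated when the model did not supply one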
@@ -564,11 +570,7 @@ class ToolCallPartDelta:
         if self.tool_name_delta is None or self.args_delta is None:
             return None
 
-        return ToolCallPart(
-            self.tool_name_delta,
-            self.args_delta,
-            self.tool_call_id,
-        )
+        return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
 
     @overload
     def apply(self, part: ModelResponsePart) -> ToolCallPart: ...

@@ -620,20 +622,11 @@ class ToolCallPartDelta:
             delta = replace(delta, args_delta=updated_args_delta)
 
         if self.tool_call_id:
-            # Set the tool_call_id if it wasn't present, otherwise error if it has changed
-            if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
-                raise UnexpectedModelBehavior(
-                    f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
-                )
             delta = replace(delta, tool_call_id=self.tool_call_id)
 
         # If we now have enough data to create a full ToolCallPart, do so
         if delta.tool_name_delta is not None and delta.args_delta is not None:
-            return ToolCallPart(
-                delta.tool_name_delta,
-                delta.args_delta,
-                delta.tool_call_id,
-            )
+            return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())
 
         return delta

@@ -656,11 +649,6 @@ class ToolCallPartDelta:
             part = replace(part, args=updated_dict)
 
         if self.tool_call_id:
-            # Replace the tool_call_id entirely if given
-            if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
-                raise UnexpectedModelBehavior(
-                    f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
-                )
             part = replace(part, tool_call_id=self.tool_call_id)
         return part
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/__init__.py

@@ -431,32 +431,41 @@ def infer_model(model: Model | KnownModelName) -> Model:
         raise UserError(f'Unknown model: {model}')
 
 
-def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
-    """Cached HTTPX async client
+def cached_async_http_client(*, provider: str | None = None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
+    """Cached HTTPX async client that creates a separate client for each provider.
+
+    The client is cached based on the provider parameter. If provider is None, it's used for non-provider specific
+    requests (like downloading images). Multiple agents and calls can share the same client when they use the same provider.
 
     There are good reasons why in production you should use a `httpx.AsyncClient` as an async context manager as
     described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
-    examples, it's very useful not to
+    examples, it's very useful not to.
 
     The default timeouts match those of OpenAI,
     see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
     """
-    client = _cached_async_http_client(timeout=timeout, connect=connect)
+    client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
     if client.is_closed:
         # This happens if the context manager is used, so we need to create a new client.
         _cached_async_http_client.cache_clear()
-        client = _cached_async_http_client(timeout=timeout, connect=connect)
+        client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
     return client
 
 
 @cache
-def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
+def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
+        transport=_cached_async_http_transport(),
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
     )
 
 
+@cache
+def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
+    return httpx.AsyncHTTPTransport()
+
+
 @cache
 def get_user_agent() -> str:
     """Get the user agent string for the HTTP client."""
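A sketch of the per-provider caching behaviour described in the new docstring (the provider strings here are just cache keys for the example; any consistent value behaves the same way):

    from pydantic_ai.models import cached_async_http_client

    a = cached_async_http_client(provider='openai')
    b = cached_async_http_client(provider='openai')
    c = cached_async_http_client(provider='anthropic')

    assert a is b      # same provider key -> same cached httpx.AsyncClient
    assert a is not c  # different provider key -> a separate client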
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/anthropic.py

@@ -10,8 +10,7 @@ from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload
 
 from anthropic.types import DocumentBlockParam
-from
-from typing_extensions import assert_never, deprecated
+from typing_extensions import assert_never
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id

@@ -112,34 +111,11 @@ class AnthropicModel(Model):
     _model_name: AnthropicModelName = field(repr=False)
     _system: str = field(default='anthropic', repr=False)
 
-    @overload
     def __init__(
         self,
         model_name: AnthropicModelName,
         *,
         provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
-    ) -> None: ...
-
-    @deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
-    @overload
-    def __init__(
-        self,
-        model_name: AnthropicModelName,
-        *,
-        provider: None = None,
-        api_key: str | None = None,
-        anthropic_client: AsyncAnthropic | None = None,
-        http_client: AsyncHTTPClient | None = None,
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: AnthropicModelName,
-        *,
-        provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
-        api_key: str | None = None,
-        anthropic_client: AsyncAnthropic | None = None,
-        http_client: AsyncHTTPClient | None = None,
     ):
         """Initialize an Anthropic model.
 

@@ -148,27 +124,12 @@ class AnthropicModel(Model):
                [here](https://docs.anthropic.com/en/docs/about-claude/models).
            provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
                instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
-            api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
-                will be used if available.
-            anthropic_client: An existing
-                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
-                client to use, if provided, `api_key` and `http_client` must be `None`.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self._model_name = model_name
 
-        if provider
-            self.client = provider.client
-        elif anthropic_client is not None:
-            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
-            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
-            self.client = anthropic_client
-        elif http_client is not None:
-            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
-        else:
-            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
 
     @property
     def base_url(self) -> str:
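With the deprecated `api_key`/`anthropic_client`/`http_client` overloads removed, configuration now goes through the provider. A hedged sketch (the model name is illustrative, and `AnthropicProvider` is assumed from `pydantic_ai/providers/anthropic.py` in the file list above):

    from pydantic_ai.models.anthropic import AnthropicModel
    from pydantic_ai.providers.anthropic import AnthropicProvider

    # String form: the inferred provider is expected to read ANTHROPIC_API_KEY from the environment.
    model = AnthropicModel('claude-3-5-sonnet-latest', provider='anthropic')

    # Explicit provider instance when the key (or a custom client) must be passed in code.
    model = AnthropicModel(
        'claude-3-5-sonnet-latest',
        provider=AnthropicProvider(api_key='your-api-key'),  # assumed constructor parameter
    )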
@@ -326,7 +287,7 @@ class AnthropicModel(Model):
                     user_content_params.append(content)
                 elif isinstance(request_part, ToolReturnPart):
                     tool_result_block_param = ToolResultBlockParam(
-                        tool_use_id=_guard_tool_call_id(t=request_part
+                        tool_use_id=_guard_tool_call_id(t=request_part),
                         type='tool_result',
                         content=request_part.model_response_str(),
                         is_error=False,

@@ -337,7 +298,7 @@ class AnthropicModel(Model):
                         retry_param = TextBlockParam(type='text', text=request_part.model_response())
                     else:
                         retry_param = ToolResultBlockParam(
-                            tool_use_id=_guard_tool_call_id(t=request_part
+                            tool_use_id=_guard_tool_call_id(t=request_part),
                             type='tool_result',
                             content=request_part.model_response(),
                             is_error=True,

@@ -351,7 +312,7 @@ class AnthropicModel(Model):
                     assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                 else:
                     tool_use_block_param = ToolUseBlockParam(
-                        id=_guard_tool_call_id(t=response_part
+                        id=_guard_tool_call_id(t=response_part),
                         type='tool_use',
                         name=response_part.tool_name,
                         input=response_part.args_as_dict(),
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/bedrock.py

@@ -143,14 +143,15 @@ class BedrockConverseModel(Model):
            model_name: The name of the model to use.
            model_name: The name of the Bedrock model to use. List of model names available
                [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
-            provider: The provider to use.
+            provider: The provider to use for authentication and API access. Can be either the string
+                'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
+                created using the other parameters.
        """
        self._model_name = model_name
 
        if isinstance(provider, str):
-            self.client = cast('BedrockRuntimeClient', provider.client)
+            provider = infer_provider(provider)
+        self.client = cast('BedrockRuntimeClient', provider.client)
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
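The Bedrock constructor follows the same provider-only pattern. A minimal sketch (the model id is illustrative; the string provider is resolved via `infer_provider` and is expected to pick up the ambient AWS credentials and region):

    from pydantic_ai.models.bedrock import BedrockConverseModel

    model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider='bedrock')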
@@ -345,7 +346,7 @@ class BedrockConverseModel(Model):
                     content.append({'text': item.content})
                 else:
                     assert isinstance(item, ToolCallPart)
-                    content.append(self._map_tool_call(item))
+                    content.append(self._map_tool_call(item))
                 bedrock_messages.append({'role': 'assistant', 'content': content})
             else:
                 assert_never(m)

@@ -394,13 +395,8 @@ class BedrockConverseModel(Model):
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
-        assert t.tool_call_id is not None
         return {
-            'toolUse': {
-                'toolUseId': t.tool_call_id,
-                'name': t.tool_name,
-                'input': t.args_as_dict(),
-            }
+            'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
         }
{pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.46}/pydantic_ai/models/cohere.py

@@ -3,14 +3,13 @@ from __future__ import annotations as _annotations
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal, Union, cast
+from typing import Literal, Union, cast
 
 from cohere import TextAssistantMessageContentItem
-from
-from typing_extensions import assert_never, deprecated
+from typing_extensions import assert_never
 
 from .. import ModelHTTPError, result
-from .._utils import guard_tool_call_id as _guard_tool_call_id
+from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,

@@ -29,7 +28,6 @@ from ..tools import ToolDefinition
 from . import (
     Model,
     ModelRequestParameters,
-    cached_async_http_client,
     check_allow_model_requests,
 )
 

@@ -102,37 +100,11 @@ class CohereModel(Model):
     _model_name: CohereModelName = field(repr=False)
     _system: str = field(default='cohere', repr=False)
 
-    @overload
     def __init__(
         self,
         model_name: CohereModelName,
         *,
         provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
-        api_key: None = None,
-        cohere_client: None = None,
-        http_client: None = None,
-    ) -> None: ...
-
-    @deprecated('Use the `provider` parameter instead of `api_key`, `cohere_client`, and `http_client`.')
-    @overload
-    def __init__(
-        self,
-        model_name: CohereModelName,
-        *,
-        provider: None = None,
-        api_key: str | None = None,
-        cohere_client: AsyncClientV2 | None = None,
-        http_client: AsyncHTTPClient | None = None,
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: CohereModelName,
-        *,
-        provider: Literal['cohere'] | Provider[AsyncClientV2] | None = None,
-        api_key: str | None = None,
-        cohere_client: AsyncClientV2 | None = None,
-        http_client: AsyncHTTPClient | None = None,
     ):
         """Initialize an Cohere model.
 

@@ -142,24 +114,12 @@ class CohereModel(Model):
            provider: The provider to use for authentication and API access. Can be either the string
                'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
                created using the other parameters.
-            api_key: The API key to use for authentication, if not provided, the
-                `CO_API_KEY` environment variable will be used if available.
-            cohere_client: An existing Cohere async client to use. If provided,
-                `api_key` and `http_client` must be `None`.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self._model_name: CohereModelName = model_name
 
-        if provider
-            self.client = provider.client
-        elif cohere_client is not None:
-            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
-            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
-            self.client = cohere_client
-        else:
-            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client or cached_async_http_client())
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
 
     @property
     def base_url(self) -> str:
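The Cohere model gets the same treatment. A short hedged sketch (the model name is illustrative; the `CO_API_KEY` fallback is the one mentioned in the removed docstring above):

    from pydantic_ai.models.cohere import CohereModel

    # The string provider is resolved with infer_provider and is expected to read CO_API_KEY from the environment.
    model = CohereModel('command-r-plus', provider='cohere')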
@@ -225,7 +185,7 @@ class CohereModel(Model):
                     ToolCallPart(
                         tool_name=c.function.name,
                         args=c.function.arguments,
-                        tool_call_id=c.id,
+                        tool_call_id=c.id or _generate_tool_call_id(),
                     )
                 )
         return ModelResponse(parts=parts, model_name=self._model_name)

@@ -262,7 +222,7 @@ class CohereModel(Model):
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
         return ToolCallV2(
-            id=_guard_tool_call_id(t=t
+            id=_guard_tool_call_id(t=t),
             type='function',
             function=ToolCallV2Function(
                 name=t.tool_name,

@@ -294,7 +254,7 @@ class CohereModel(Model):
         elif isinstance(part, ToolReturnPart):
             yield ToolChatMessageV2(
                 role='tool',
-                tool_call_id=_guard_tool_call_id(t=part
+                tool_call_id=_guard_tool_call_id(t=part),
                 content=part.model_response_str(),
             )
         elif isinstance(part, RetryPromptPart):

@@ -303,7 +263,7 @@ class CohereModel(Model):
         else:
             yield ToolChatMessageV2(
                 role='tool',
-                tool_call_id=_guard_tool_call_id(t=part
+                tool_call_id=_guard_tool_call_id(t=part),
                 content=part.model_response(),
             )
         else: