pydantic-ai-slim 0.0.44__tar.gz → 0.0.45__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/PKG-INFO +2 -2
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_parts_manager.py +7 -1
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_utils.py +12 -6
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/exceptions.py +2 -2
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/messages.py +15 -27
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/anthropic.py +7 -46
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/bedrock.py +7 -11
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/cohere.py +10 -50
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/gemini.py +18 -73
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/groq.py +9 -53
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/mistral.py +12 -51
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/openai.py +15 -67
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/anthropic.py +4 -5
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/azure.py +8 -9
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/bedrock.py +2 -1
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/cohere.py +4 -5
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/deepseek.py +4 -4
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/google_gla.py +3 -2
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/google_vertex.py +2 -3
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/groq.py +4 -5
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/mistral.py +4 -5
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/openai.py +5 -8
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pyproject.toml +2 -2
- pydantic_ai_slim-0.0.44/pydantic_ai/models/vertexai.py +0 -260
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/README.md +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/agent.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.44 → pydantic_ai_slim-0.0.45}/pydantic_ai/usage.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.44
|
|
3
|
+
Version: 0.0.45
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
|
|
|
29
29
|
Requires-Dist: griffe>=1.3.2
|
|
30
30
|
Requires-Dist: httpx>=0.27
|
|
31
31
|
Requires-Dist: opentelemetry-api>=1.28.0
|
|
32
|
-
Requires-Dist: pydantic-graph==0.0.44
|
|
32
|
+
Requires-Dist: pydantic-graph==0.0.45
|
|
33
33
|
Requires-Dist: pydantic>=2.10
|
|
34
34
|
Requires-Dist: typing-inspection>=0.4.0
|
|
35
35
|
Provides-Extra: anthropic
|
|
@@ -29,6 +29,8 @@ from pydantic_ai.messages import (
|
|
|
29
29
|
ToolCallPartDelta,
|
|
30
30
|
)
|
|
31
31
|
|
|
32
|
+
from ._utils import generate_tool_call_id as _generate_tool_call_id
|
|
33
|
+
|
|
32
34
|
VendorId = Hashable
|
|
33
35
|
"""
|
|
34
36
|
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
|
|
@@ -221,7 +223,11 @@ class ModelResponsePartsManager:
|
|
|
221
223
|
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
|
|
222
224
|
has been added to the manager, or replaced an existing part.
|
|
223
225
|
"""
|
|
224
|
-
new_part = ToolCallPart(
|
|
226
|
+
new_part = ToolCallPart(
|
|
227
|
+
tool_name=tool_name,
|
|
228
|
+
args=args,
|
|
229
|
+
tool_call_id=tool_call_id or _generate_tool_call_id(),
|
|
230
|
+
)
|
|
225
231
|
if vendor_part_id is None:
|
|
226
232
|
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
|
|
227
233
|
new_part_index = len(self._parts)
|
|
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import time
|
|
5
|
+
import uuid
|
|
5
6
|
from collections.abc import AsyncIterable, AsyncIterator, Iterator
|
|
6
7
|
from contextlib import asynccontextmanager, suppress
|
|
7
8
|
from dataclasses import dataclass, is_dataclass
|
|
@@ -195,12 +196,17 @@ def now_utc() -> datetime:
|
|
|
195
196
|
return datetime.now(tz=timezone.utc)
|
|
196
197
|
|
|
197
198
|
|
|
198
|
-
def guard_tool_call_id(
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
199
|
+
def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
|
|
200
|
+
"""Type guard that either returns the tool call id or generates a new one if it's None."""
|
|
201
|
+
return t.tool_call_id or generate_tool_call_id()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def generate_tool_call_id() -> str:
|
|
205
|
+
"""Generate a tool call id.
|
|
206
|
+
|
|
207
|
+
Ensure that the tool call id is unique.
|
|
208
|
+
"""
|
|
209
|
+
return f'pyd_ai_{uuid.uuid4().hex}'
|
|
204
210
|
|
|
205
211
|
|
|
206
212
|
class PeekableAsyncStream(Generic[T]):
|
|
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import json
|
|
4
4
|
import sys
|
|
5
5
|
|
|
6
|
-
if sys.version_info < (3, 11):
|
|
6
|
+
if sys.version_info < (3, 11): # pragma: no cover
|
|
7
7
|
from exceptiongroup import ExceptionGroup
|
|
8
|
-
else:
|
|
8
|
+
else: # pragma: no cover
|
|
9
9
|
ExceptionGroup = ExceptionGroup
|
|
10
10
|
|
|
11
11
|
__all__ = (
|
|
@@ -12,7 +12,7 @@ import pydantic_core
|
|
|
12
12
|
from opentelemetry._events import Event
|
|
13
13
|
from typing_extensions import TypeAlias
|
|
14
14
|
|
|
15
|
-
from ._utils import now_utc as _now_utc
|
|
15
|
+
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
|
|
16
16
|
from .exceptions import UnexpectedModelBehavior
|
|
17
17
|
|
|
18
18
|
|
|
@@ -268,8 +268,8 @@ class ToolReturnPart:
|
|
|
268
268
|
content: Any
|
|
269
269
|
"""The return value."""
|
|
270
270
|
|
|
271
|
-
tool_call_id: str
|
|
272
|
-
"""
|
|
271
|
+
tool_call_id: str
|
|
272
|
+
"""The tool call identifier, this is used by some models including OpenAI."""
|
|
273
273
|
|
|
274
274
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
275
275
|
"""The timestamp, when the tool returned."""
|
|
@@ -328,8 +328,11 @@ class RetryPromptPart:
|
|
|
328
328
|
tool_name: str | None = None
|
|
329
329
|
"""The name of the tool that was called, if any."""
|
|
330
330
|
|
|
331
|
-
tool_call_id: str
|
|
332
|
-
"""
|
|
331
|
+
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
332
|
+
"""The tool call identifier, this is used by some models including OpenAI.
|
|
333
|
+
|
|
334
|
+
In case the tool call id is not provided by the model, PydanticAI will generate a random one.
|
|
335
|
+
"""
|
|
333
336
|
|
|
334
337
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
335
338
|
"""The timestamp, when the retry was triggered."""
|
|
@@ -406,8 +409,11 @@ class ToolCallPart:
|
|
|
406
409
|
This is stored either as a JSON string or a Python dictionary depending on how data was received.
|
|
407
410
|
"""
|
|
408
411
|
|
|
409
|
-
tool_call_id: str
|
|
410
|
-
"""
|
|
412
|
+
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
413
|
+
"""The tool call identifier, this is used by some models including OpenAI.
|
|
414
|
+
|
|
415
|
+
In case the tool call id is not provided by the model, PydanticAI will generate a random one.
|
|
416
|
+
"""
|
|
411
417
|
|
|
412
418
|
part_kind: Literal['tool-call'] = 'tool-call'
|
|
413
419
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
@@ -564,11 +570,7 @@ class ToolCallPartDelta:
|
|
|
564
570
|
if self.tool_name_delta is None or self.args_delta is None:
|
|
565
571
|
return None
|
|
566
572
|
|
|
567
|
-
return ToolCallPart(
|
|
568
|
-
self.tool_name_delta,
|
|
569
|
-
self.args_delta,
|
|
570
|
-
self.tool_call_id,
|
|
571
|
-
)
|
|
573
|
+
return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
|
|
572
574
|
|
|
573
575
|
@overload
|
|
574
576
|
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
|
|
@@ -620,20 +622,11 @@ class ToolCallPartDelta:
|
|
|
620
622
|
delta = replace(delta, args_delta=updated_args_delta)
|
|
621
623
|
|
|
622
624
|
if self.tool_call_id:
|
|
623
|
-
# Set the tool_call_id if it wasn't present, otherwise error if it has changed
|
|
624
|
-
if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
|
|
625
|
-
raise UnexpectedModelBehavior(
|
|
626
|
-
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
|
|
627
|
-
)
|
|
628
625
|
delta = replace(delta, tool_call_id=self.tool_call_id)
|
|
629
626
|
|
|
630
627
|
# If we now have enough data to create a full ToolCallPart, do so
|
|
631
628
|
if delta.tool_name_delta is not None and delta.args_delta is not None:
|
|
632
|
-
return ToolCallPart(
|
|
633
|
-
delta.tool_name_delta,
|
|
634
|
-
delta.args_delta,
|
|
635
|
-
delta.tool_call_id,
|
|
636
|
-
)
|
|
629
|
+
return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())
|
|
637
630
|
|
|
638
631
|
return delta
|
|
639
632
|
|
|
@@ -656,11 +649,6 @@ class ToolCallPartDelta:
|
|
|
656
649
|
part = replace(part, args=updated_dict)
|
|
657
650
|
|
|
658
651
|
if self.tool_call_id:
|
|
659
|
-
# Replace the tool_call_id entirely if given
|
|
660
|
-
if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
|
|
661
|
-
raise UnexpectedModelBehavior(
|
|
662
|
-
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
|
|
663
|
-
)
|
|
664
652
|
part = replace(part, tool_call_id=self.tool_call_id)
|
|
665
653
|
return part
|
|
666
654
|
|
|
@@ -10,8 +10,7 @@ from json import JSONDecodeError, loads as json_loads
|
|
|
10
10
|
from typing import Any, Literal, Union, cast, overload
|
|
11
11
|
|
|
12
12
|
from anthropic.types import DocumentBlockParam
|
|
13
|
-
from
|
|
14
|
-
from typing_extensions import assert_never, deprecated
|
|
13
|
+
from typing_extensions import assert_never
|
|
15
14
|
|
|
16
15
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
17
16
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
@@ -112,34 +111,11 @@ class AnthropicModel(Model):
|
|
|
112
111
|
_model_name: AnthropicModelName = field(repr=False)
|
|
113
112
|
_system: str = field(default='anthropic', repr=False)
|
|
114
113
|
|
|
115
|
-
@overload
|
|
116
114
|
def __init__(
|
|
117
115
|
self,
|
|
118
116
|
model_name: AnthropicModelName,
|
|
119
117
|
*,
|
|
120
118
|
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
|
|
121
|
-
) -> None: ...
|
|
122
|
-
|
|
123
|
-
@deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
|
|
124
|
-
@overload
|
|
125
|
-
def __init__(
|
|
126
|
-
self,
|
|
127
|
-
model_name: AnthropicModelName,
|
|
128
|
-
*,
|
|
129
|
-
provider: None = None,
|
|
130
|
-
api_key: str | None = None,
|
|
131
|
-
anthropic_client: AsyncAnthropic | None = None,
|
|
132
|
-
http_client: AsyncHTTPClient | None = None,
|
|
133
|
-
) -> None: ...
|
|
134
|
-
|
|
135
|
-
def __init__(
|
|
136
|
-
self,
|
|
137
|
-
model_name: AnthropicModelName,
|
|
138
|
-
*,
|
|
139
|
-
provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
|
|
140
|
-
api_key: str | None = None,
|
|
141
|
-
anthropic_client: AsyncAnthropic | None = None,
|
|
142
|
-
http_client: AsyncHTTPClient | None = None,
|
|
143
119
|
):
|
|
144
120
|
"""Initialize an Anthropic model.
|
|
145
121
|
|
|
@@ -148,27 +124,12 @@ class AnthropicModel(Model):
|
|
|
148
124
|
[here](https://docs.anthropic.com/en/docs/about-claude/models).
|
|
149
125
|
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
|
|
150
126
|
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
|
|
151
|
-
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
|
|
152
|
-
will be used if available.
|
|
153
|
-
anthropic_client: An existing
|
|
154
|
-
[`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
|
|
155
|
-
client to use, if provided, `api_key` and `http_client` must be `None`.
|
|
156
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
157
127
|
"""
|
|
158
128
|
self._model_name = model_name
|
|
159
129
|
|
|
160
|
-
if provider
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
self.client = provider.client
|
|
164
|
-
elif anthropic_client is not None:
|
|
165
|
-
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
|
|
166
|
-
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
|
|
167
|
-
self.client = anthropic_client
|
|
168
|
-
elif http_client is not None:
|
|
169
|
-
self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
|
|
170
|
-
else:
|
|
171
|
-
self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
|
|
130
|
+
if isinstance(provider, str):
|
|
131
|
+
provider = infer_provider(provider)
|
|
132
|
+
self.client = provider.client
|
|
172
133
|
|
|
173
134
|
@property
|
|
174
135
|
def base_url(self) -> str:
|
|
@@ -326,7 +287,7 @@ class AnthropicModel(Model):
|
|
|
326
287
|
user_content_params.append(content)
|
|
327
288
|
elif isinstance(request_part, ToolReturnPart):
|
|
328
289
|
tool_result_block_param = ToolResultBlockParam(
|
|
329
|
-
tool_use_id=_guard_tool_call_id(t=request_part
|
|
290
|
+
tool_use_id=_guard_tool_call_id(t=request_part),
|
|
330
291
|
type='tool_result',
|
|
331
292
|
content=request_part.model_response_str(),
|
|
332
293
|
is_error=False,
|
|
@@ -337,7 +298,7 @@ class AnthropicModel(Model):
|
|
|
337
298
|
retry_param = TextBlockParam(type='text', text=request_part.model_response())
|
|
338
299
|
else:
|
|
339
300
|
retry_param = ToolResultBlockParam(
|
|
340
|
-
tool_use_id=_guard_tool_call_id(t=request_part
|
|
301
|
+
tool_use_id=_guard_tool_call_id(t=request_part),
|
|
341
302
|
type='tool_result',
|
|
342
303
|
content=request_part.model_response(),
|
|
343
304
|
is_error=True,
|
|
@@ -351,7 +312,7 @@ class AnthropicModel(Model):
|
|
|
351
312
|
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
|
|
352
313
|
else:
|
|
353
314
|
tool_use_block_param = ToolUseBlockParam(
|
|
354
|
-
id=_guard_tool_call_id(t=response_part
|
|
315
|
+
id=_guard_tool_call_id(t=response_part),
|
|
355
316
|
type='tool_use',
|
|
356
317
|
name=response_part.tool_name,
|
|
357
318
|
input=response_part.args_as_dict(),
|
|
@@ -143,14 +143,15 @@ class BedrockConverseModel(Model):
|
|
|
143
143
|
model_name: The name of the model to use.
|
|
144
144
|
model_name: The name of the Bedrock model to use. List of model names available
|
|
145
145
|
[here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
|
|
146
|
-
provider: The provider to use.
|
|
146
|
+
provider: The provider to use for authentication and API access. Can be either the string
|
|
147
|
+
'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
|
|
148
|
+
created using the other parameters.
|
|
147
149
|
"""
|
|
148
150
|
self._model_name = model_name
|
|
149
151
|
|
|
150
152
|
if isinstance(provider, str):
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
self.client = cast('BedrockRuntimeClient', provider.client)
|
|
153
|
+
provider = infer_provider(provider)
|
|
154
|
+
self.client = cast('BedrockRuntimeClient', provider.client)
|
|
154
155
|
|
|
155
156
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
|
|
156
157
|
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
@@ -345,7 +346,7 @@ class BedrockConverseModel(Model):
|
|
|
345
346
|
content.append({'text': item.content})
|
|
346
347
|
else:
|
|
347
348
|
assert isinstance(item, ToolCallPart)
|
|
348
|
-
content.append(self._map_tool_call(item))
|
|
349
|
+
content.append(self._map_tool_call(item))
|
|
349
350
|
bedrock_messages.append({'role': 'assistant', 'content': content})
|
|
350
351
|
else:
|
|
351
352
|
assert_never(m)
|
|
@@ -394,13 +395,8 @@ class BedrockConverseModel(Model):
|
|
|
394
395
|
|
|
395
396
|
@staticmethod
|
|
396
397
|
def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
|
|
397
|
-
assert t.tool_call_id is not None
|
|
398
398
|
return {
|
|
399
|
-
'toolUse': {
|
|
400
|
-
'toolUseId': t.tool_call_id,
|
|
401
|
-
'name': t.tool_name,
|
|
402
|
-
'input': t.args_as_dict(),
|
|
403
|
-
}
|
|
399
|
+
'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
|
|
404
400
|
}
|
|
405
401
|
|
|
406
402
|
|
|
@@ -3,14 +3,13 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
from collections.abc import Iterable
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
5
|
from itertools import chain
|
|
6
|
-
from typing import Literal, Union, cast
|
|
6
|
+
from typing import Literal, Union, cast
|
|
7
7
|
|
|
8
8
|
from cohere import TextAssistantMessageContentItem
|
|
9
|
-
from
|
|
10
|
-
from typing_extensions import assert_never, deprecated
|
|
9
|
+
from typing_extensions import assert_never
|
|
11
10
|
|
|
12
11
|
from .. import ModelHTTPError, result
|
|
13
|
-
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
12
|
+
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
|
|
14
13
|
from ..messages import (
|
|
15
14
|
ModelMessage,
|
|
16
15
|
ModelRequest,
|
|
@@ -29,7 +28,6 @@ from ..tools import ToolDefinition
|
|
|
29
28
|
from . import (
|
|
30
29
|
Model,
|
|
31
30
|
ModelRequestParameters,
|
|
32
|
-
cached_async_http_client,
|
|
33
31
|
check_allow_model_requests,
|
|
34
32
|
)
|
|
35
33
|
|
|
@@ -102,37 +100,11 @@ class CohereModel(Model):
|
|
|
102
100
|
_model_name: CohereModelName = field(repr=False)
|
|
103
101
|
_system: str = field(default='cohere', repr=False)
|
|
104
102
|
|
|
105
|
-
@overload
|
|
106
103
|
def __init__(
|
|
107
104
|
self,
|
|
108
105
|
model_name: CohereModelName,
|
|
109
106
|
*,
|
|
110
107
|
provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
|
|
111
|
-
api_key: None = None,
|
|
112
|
-
cohere_client: None = None,
|
|
113
|
-
http_client: None = None,
|
|
114
|
-
) -> None: ...
|
|
115
|
-
|
|
116
|
-
@deprecated('Use the `provider` parameter instead of `api_key`, `cohere_client`, and `http_client`.')
|
|
117
|
-
@overload
|
|
118
|
-
def __init__(
|
|
119
|
-
self,
|
|
120
|
-
model_name: CohereModelName,
|
|
121
|
-
*,
|
|
122
|
-
provider: None = None,
|
|
123
|
-
api_key: str | None = None,
|
|
124
|
-
cohere_client: AsyncClientV2 | None = None,
|
|
125
|
-
http_client: AsyncHTTPClient | None = None,
|
|
126
|
-
) -> None: ...
|
|
127
|
-
|
|
128
|
-
def __init__(
|
|
129
|
-
self,
|
|
130
|
-
model_name: CohereModelName,
|
|
131
|
-
*,
|
|
132
|
-
provider: Literal['cohere'] | Provider[AsyncClientV2] | None = None,
|
|
133
|
-
api_key: str | None = None,
|
|
134
|
-
cohere_client: AsyncClientV2 | None = None,
|
|
135
|
-
http_client: AsyncHTTPClient | None = None,
|
|
136
108
|
):
|
|
137
109
|
"""Initialize an Cohere model.
|
|
138
110
|
|
|
@@ -142,24 +114,12 @@ class CohereModel(Model):
|
|
|
142
114
|
provider: The provider to use for authentication and API access. Can be either the string
|
|
143
115
|
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
|
|
144
116
|
created using the other parameters.
|
|
145
|
-
api_key: The API key to use for authentication, if not provided, the
|
|
146
|
-
`CO_API_KEY` environment variable will be used if available.
|
|
147
|
-
cohere_client: An existing Cohere async client to use. If provided,
|
|
148
|
-
`api_key` and `http_client` must be `None`.
|
|
149
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
150
117
|
"""
|
|
151
118
|
self._model_name: CohereModelName = model_name
|
|
152
119
|
|
|
153
|
-
if provider
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
self.client = provider.client
|
|
157
|
-
elif cohere_client is not None:
|
|
158
|
-
assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
|
|
159
|
-
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
|
|
160
|
-
self.client = cohere_client
|
|
161
|
-
else:
|
|
162
|
-
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client or cached_async_http_client())
|
|
120
|
+
if isinstance(provider, str):
|
|
121
|
+
provider = infer_provider(provider)
|
|
122
|
+
self.client = provider.client
|
|
163
123
|
|
|
164
124
|
@property
|
|
165
125
|
def base_url(self) -> str:
|
|
@@ -225,7 +185,7 @@ class CohereModel(Model):
|
|
|
225
185
|
ToolCallPart(
|
|
226
186
|
tool_name=c.function.name,
|
|
227
187
|
args=c.function.arguments,
|
|
228
|
-
tool_call_id=c.id,
|
|
188
|
+
tool_call_id=c.id or _generate_tool_call_id(),
|
|
229
189
|
)
|
|
230
190
|
)
|
|
231
191
|
return ModelResponse(parts=parts, model_name=self._model_name)
|
|
@@ -262,7 +222,7 @@ class CohereModel(Model):
|
|
|
262
222
|
@staticmethod
|
|
263
223
|
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
|
|
264
224
|
return ToolCallV2(
|
|
265
|
-
id=_guard_tool_call_id(t=t
|
|
225
|
+
id=_guard_tool_call_id(t=t),
|
|
266
226
|
type='function',
|
|
267
227
|
function=ToolCallV2Function(
|
|
268
228
|
name=t.tool_name,
|
|
@@ -294,7 +254,7 @@ class CohereModel(Model):
|
|
|
294
254
|
elif isinstance(part, ToolReturnPart):
|
|
295
255
|
yield ToolChatMessageV2(
|
|
296
256
|
role='tool',
|
|
297
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
257
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
298
258
|
content=part.model_response_str(),
|
|
299
259
|
)
|
|
300
260
|
elif isinstance(part, RetryPromptPart):
|
|
@@ -303,7 +263,7 @@ class CohereModel(Model):
|
|
|
303
263
|
else:
|
|
304
264
|
yield ToolChatMessageV2(
|
|
305
265
|
role='tool',
|
|
306
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
266
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
307
267
|
content=part.model_response(),
|
|
308
268
|
)
|
|
309
269
|
else:
|
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import os
|
|
5
4
|
import re
|
|
6
5
|
from collections.abc import AsyncIterator, Sequence
|
|
7
6
|
from contextlib import asynccontextmanager
|
|
8
7
|
from copy import deepcopy
|
|
9
8
|
from dataclasses import dataclass, field
|
|
10
9
|
from datetime import datetime
|
|
11
|
-
from typing import Annotated, Any, Literal, Protocol, Union, cast
|
|
10
|
+
from typing import Annotated, Any, Literal, Protocol, Union, cast
|
|
12
11
|
from uuid import uuid4
|
|
13
12
|
|
|
13
|
+
import httpx
|
|
14
14
|
import pydantic
|
|
15
|
-
from httpx import USE_CLIENT_DEFAULT,
|
|
16
|
-
from typing_extensions import NotRequired, TypedDict, assert_never
|
|
15
|
+
from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
|
|
16
|
+
from typing_extensions import NotRequired, TypedDict, assert_never
|
|
17
17
|
|
|
18
18
|
from pydantic_ai.providers import Provider, infer_provider
|
|
19
19
|
|
|
@@ -85,78 +85,36 @@ class GeminiModel(Model):
|
|
|
85
85
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
86
86
|
"""
|
|
87
87
|
|
|
88
|
-
client:
|
|
88
|
+
client: httpx.AsyncClient = field(repr=False)
|
|
89
89
|
|
|
90
90
|
_model_name: GeminiModelName = field(repr=False)
|
|
91
|
-
_provider: Literal['google-gla', 'google-vertex'] | Provider[
|
|
91
|
+
_provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
|
|
92
92
|
_auth: AuthProtocol | None = field(repr=False)
|
|
93
93
|
_url: str | None = field(repr=False)
|
|
94
94
|
_system: str = field(default='gemini', repr=False)
|
|
95
95
|
|
|
96
|
-
@overload
|
|
97
96
|
def __init__(
|
|
98
97
|
self,
|
|
99
98
|
model_name: GeminiModelName,
|
|
100
99
|
*,
|
|
101
|
-
provider: Literal['google-gla', 'google-vertex'] | Provider[
|
|
102
|
-
) -> None: ...
|
|
103
|
-
|
|
104
|
-
@deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
|
|
105
|
-
@overload
|
|
106
|
-
def __init__(
|
|
107
|
-
self,
|
|
108
|
-
model_name: GeminiModelName,
|
|
109
|
-
*,
|
|
110
|
-
provider: None = None,
|
|
111
|
-
api_key: str | None = None,
|
|
112
|
-
http_client: AsyncHTTPClient | None = None,
|
|
113
|
-
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
|
|
114
|
-
) -> None: ...
|
|
115
|
-
|
|
116
|
-
def __init__(
|
|
117
|
-
self,
|
|
118
|
-
model_name: GeminiModelName,
|
|
119
|
-
*,
|
|
120
|
-
provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
|
|
121
|
-
api_key: str | None = None,
|
|
122
|
-
http_client: AsyncHTTPClient | None = None,
|
|
123
|
-
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
|
|
100
|
+
provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
|
|
124
101
|
):
|
|
125
102
|
"""Initialize a Gemini model.
|
|
126
103
|
|
|
127
104
|
Args:
|
|
128
105
|
model_name: The name of the model to use.
|
|
129
|
-
provider: The provider to use for the
|
|
130
|
-
|
|
131
|
-
will be
|
|
132
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
133
|
-
url_template: The URL template to use for making requests, you shouldn't need to change this,
|
|
134
|
-
docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
|
|
135
|
-
`model` is substituted with the model name, and `function` is added to the end of the URL.
|
|
106
|
+
provider: The provider to use for authentication and API access. Can be either the string
|
|
107
|
+
'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
|
|
108
|
+
If not provided, a new provider will be created using the other parameters.
|
|
136
109
|
"""
|
|
137
110
|
self._model_name = model_name
|
|
138
111
|
self._provider = provider
|
|
139
112
|
|
|
140
|
-
if provider
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
self._url = str(self.client.base_url)
|
|
146
|
-
else:
|
|
147
|
-
if api_key is None:
|
|
148
|
-
if env_api_key := os.getenv('GEMINI_API_KEY'):
|
|
149
|
-
api_key = env_api_key
|
|
150
|
-
else:
|
|
151
|
-
raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
|
|
152
|
-
self.client = http_client or cached_async_http_client()
|
|
153
|
-
self._auth = ApiKeyAuth(api_key)
|
|
154
|
-
self._url = url_template.format(model=model_name)
|
|
155
|
-
|
|
156
|
-
@property
|
|
157
|
-
def auth(self) -> AuthProtocol:
|
|
158
|
-
assert self._auth is not None, 'Auth not initialized'
|
|
159
|
-
return self._auth
|
|
113
|
+
if isinstance(provider, str):
|
|
114
|
+
provider = infer_provider(provider)
|
|
115
|
+
self._system = provider.name
|
|
116
|
+
self.client = provider.client
|
|
117
|
+
self._url = str(self.client.base_url)
|
|
160
118
|
|
|
161
119
|
@property
|
|
162
120
|
def base_url(self) -> str:
|
|
@@ -252,18 +210,10 @@ class GeminiModel(Model):
|
|
|
252
210
|
if generation_config:
|
|
253
211
|
request_data['generation_config'] = generation_config
|
|
254
212
|
|
|
255
|
-
headers = {
|
|
256
|
-
|
|
257
|
-
'User-Agent': get_user_agent(),
|
|
258
|
-
}
|
|
259
|
-
if self._provider is None: # pragma: no cover
|
|
260
|
-
url = self.base_url + ('streamGenerateContent' if streamed else 'generateContent')
|
|
261
|
-
headers.update(await self.auth.headers())
|
|
262
|
-
else:
|
|
263
|
-
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
|
|
213
|
+
headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
214
|
+
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
|
|
264
215
|
|
|
265
216
|
request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
|
|
266
|
-
|
|
267
217
|
async with self.client.stream(
|
|
268
218
|
'POST',
|
|
269
219
|
url,
|
|
@@ -603,12 +553,7 @@ def _process_response_from_parts(
|
|
|
603
553
|
if 'text' in part:
|
|
604
554
|
items.append(TextPart(content=part['text']))
|
|
605
555
|
elif 'function_call' in part:
|
|
606
|
-
items.append(
|
|
607
|
-
ToolCallPart(
|
|
608
|
-
tool_name=part['function_call']['name'],
|
|
609
|
-
args=part['function_call']['args'],
|
|
610
|
-
)
|
|
611
|
-
)
|
|
556
|
+
items.append(ToolCallPart(tool_name=part['function_call']['name'], args=part['function_call']['args']))
|
|
612
557
|
elif 'function_response' in part:
|
|
613
558
|
raise UnexpectedModelBehavior(
|
|
614
559
|
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
|