pydantic-ai-slim 0.0.26__tar.gz → 0.0.27__tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as possibly problematic.
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/PKG-INFO +3 -2
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/__init__.py +11 -1
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/agent.py +3 -3
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/exceptions.py +42 -1
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/anthropic.py +20 -15
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/cohere.py +18 -12
- pydantic_ai_slim-0.0.27/pydantic_ai/models/fallback.py +116 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/function.py +16 -7
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/gemini.py +30 -14
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/groq.py +24 -19
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/mistral.py +21 -14
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/openai.py +26 -21
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pyproject.toml +4 -5
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/README.md +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/vertexai.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.26 → pydantic_ai_slim-0.0.27}/pydantic_ai/usage.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.26
+Version: 0.0.27
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -25,10 +25,11 @@ Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
 Requires-Dist: eval-type-backport>=0.2.0
+Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.26
+Requires-Dist: pydantic-graph==0.0.27
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
pydantic_ai/__init__.py

@@ -1,7 +1,15 @@
 from importlib.metadata import version
 
 from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
-from .exceptions import …
+from .exceptions import (
+    AgentRunError,
+    FallbackExceptionGroup,
+    ModelHTTPError,
+    ModelRetry,
+    UnexpectedModelBehavior,
+    UsageLimitExceeded,
+    UserError,
+)
 from .messages import AudioUrl, BinaryContent, ImageUrl
 from .tools import RunContext, Tool
 
@@ -17,6 +25,8 @@ __all__ = (
     # exceptions
     'AgentRunError',
     'ModelRetry',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
    'UnexpectedModelBehavior',
    'UsageLimitExceeded',
    'UserError',
pydantic_ai/agent.py

@@ -365,7 +365,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 HandleResponseNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
-                        model_name='…
+                        model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
                     )
@@ -1214,7 +1214,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                 HandleResponseNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
-                        model_name='…
+                        model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
                     )
@@ -1357,7 +1357,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
                 HandleResponseNode(
                     model_response=ModelResponse(
                         parts=[TextPart(content='Paris', part_kind='text')],
-                        model_name='…
+                        model_name='gpt-4o',
                         timestamp=datetime.datetime(...),
                         kind='response',
                     )
pydantic_ai/exceptions.py

@@ -1,8 +1,22 @@
 from __future__ import annotations as _annotations
 
 import json
+import sys
 
-
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+else:
+    ExceptionGroup = ExceptionGroup
+
+__all__ = (
+    'ModelRetry',
+    'UserError',
+    'AgentRunError',
+    'UnexpectedModelBehavior',
+    'UsageLimitExceeded',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
+)
 
 
 class ModelRetry(Exception):
@@ -72,3 +86,30 @@ class UnexpectedModelBehavior(AgentRunError):
             return f'{self.message}, body:\n{self.body}'
         else:
             return self.message
+
+
+class ModelHTTPError(AgentRunError):
+    """Raised when an model provider response has a status code of 4xx or 5xx."""
+
+    status_code: int
+    """The HTTP status code returned by the API."""
+
+    model_name: str
+    """The name of the model associated with the error."""
+
+    body: object | None
+    """The body of the response, if available."""
+
+    message: str
+    """The error message with the status code and response body, if available."""
+
+    def __init__(self, status_code: int, model_name: str, body: object | None = None):
+        self.status_code = status_code
+        self.model_name = model_name
+        self.body = body
+        message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
+        super().__init__(message)
+
+
+class FallbackExceptionGroup(ExceptionGroup):
+    """A group of exceptions that can be raised when all fallback models fail."""
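Example (illustrative, not part of the diff): a minimal sketch of catching the new `ModelHTTPError` from calling code, assuming the usual `Agent`/`run_sync` API; the model name and prompt are placeholders.

```python
from pydantic_ai import Agent, ModelHTTPError

agent = Agent('openai:gpt-4o')  # placeholder model name

try:
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
except ModelHTTPError as exc:
    # the exception carries the provider status code, model name and raw response body
    print(exc.status_code, exc.model_name)
    print(exc.body)
```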
pydantic_ai/models/anthropic.py

@@ -11,7 +11,7 @@ from typing import Any, Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     BinaryContent,
@@ -39,7 +39,7 @@ from . import (
 )
 
 try:
-    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types import (
         ImageBlockParam,
         Message as AnthropicMessage,
@@ -220,19 +220,24 @@ class AnthropicModel(Model):
 
         system_prompt, anthropic_messages = await self._map_message(messages)
 
-… (13 removed lines not rendered in this diff view)
+        try:
+            return await self.client.messages.create(
+                max_tokens=model_settings.get('max_tokens', 1024),
+                system=system_prompt or NOT_GIVEN,
+                messages=anthropic_messages,
+                model=self._model_name,
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
pydantic_ai/models/cohere.py

@@ -9,7 +9,7 @@ from cohere import TextAssistantMessageContentItem
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import result
+from .. import ModelHTTPError, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
@@ -45,6 +45,7 @@ try:
         ToolV2Function,
         UserChatMessageV2,
     )
+    from cohere.core.api_error import ApiError
     from cohere.v2.client import OMIT
 except ImportError as _import_error:
     raise ImportError(
@@ -154,17 +155,22 @@ class CohereModel(Model):
     ) -> ChatResponse:
         tools = self._get_tools(model_request_parameters)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-… (11 removed lines not rendered in this diff view)
+        try:
+            return await self.client.chat(
+                model=self._model_name,
+                messages=cohere_messages,
+                tools=tools or OMIT,
+                max_tokens=model_settings.get('max_tokens', OMIT),
+                temperature=model_settings.get('temperature', OMIT),
+                p=model_settings.get('top_p', OMIT),
+                seed=model_settings.get('seed', OMIT),
+                presence_penalty=model_settings.get('presence_penalty', OMIT),
+                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
+            )
+        except ApiError as e:
+            if (status_code := e.status_code) and status_code >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
pydantic_ai/models/fallback.py (new file)

@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import AsyncExitStack, asynccontextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable
+
+from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+if TYPE_CHECKING:
+    from ..messages import ModelMessage, ModelResponse
+    from ..settings import ModelSettings
+    from ..usage import Usage
+
+
+@dataclass(init=False)
+class FallbackModel(Model):
+    """A model that uses one or more fallback models upon failure.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    models: list[Model]
+
+    _model_name: str = field(repr=False)
+    _fallback_on: Callable[[Exception], bool]
+
+    def __init__(
+        self,
+        default_model: Model | KnownModelName,
+        *fallback_models: Model | KnownModelName,
+        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
+    ):
+        """Initialize a fallback model instance.
+
+        Args:
+            default_model: The name or instance of the default model to use.
+            fallback_models: The names or instances of the fallback models to use upon failure.
+            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
+        """
+        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
+        self._model_name = f'FallBackModel[{", ".join(model.model_name for model in self.models)}]'
+
+        if isinstance(fallback_on, tuple):
+            self._fallback_on = _default_fallback_condition_factory(fallback_on)
+        else:
+            self._fallback_on = fallback_on
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        """Try each model in sequence until one succeeds.
+
+        In case of failure, raise a FallbackExceptionGroup with all exceptions.
+        """
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            try:
+                return await model.request(messages, model_settings, model_request_parameters)
+            except Exception as exc:
+                if self._fallback_on(exc):
+                    exceptions.append(exc)
+                    continue
+                raise exc
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        """Try each model in sequence until one succeeds."""
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            async with AsyncExitStack() as stack:
+                try:
+                    response = await stack.enter_async_context(
+                        model.request_stream(messages, model_settings, model_request_parameters)
+                    )
+                except Exception as exc:
+                    if self._fallback_on(exc):
+                        exceptions.append(exc)
+                        continue
+                    raise exc
+                yield response
+                return
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider, n/a for fallback models."""
+        return None
+
+
+def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
+    """Create a default fallback condition for the given exceptions."""
+
+    def fallback_condition(exception: Exception) -> bool:
+        return isinstance(exception, exceptions)
+
+    return fallback_condition
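Example (illustrative, not part of the diff): a minimal sketch of wiring the new `FallbackModel` into an agent, assuming the `Agent`/`run_sync` API; the model identifiers and prompt are placeholders. By default the fallback is triggered by `ModelHTTPError`, and `FallbackExceptionGroup` is raised only once every model in the chain has failed.

```python
from pydantic_ai import Agent, FallbackExceptionGroup
from pydantic_ai.models.fallback import FallbackModel

# first argument is the default model, the rest are fallbacks tried in order
fallback_model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
agent = Agent(fallback_model)

try:
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
except FallbackExceptionGroup as group:
    # one entry per model that failed the fallback condition
    for exc in group.exceptions:
        print(type(exc).__name__, exc)
```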
pydantic_ai/models/function.py

@@ -48,15 +48,23 @@ class FunctionModel(Model):
     _system: str | None = field(default=None, repr=False)
 
     @overload
-    def __init__(self, function: FunctionDef) -> None: ...
+    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...
+    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(…
+    def __init__(
+        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
+    ) -> None: ...
 
-    def __init__(…
+    def __init__(
+        self,
+        function: FunctionDef | None = None,
+        *,
+        stream_function: StreamFunctionDef | None = None,
+        model_name: str | None = None,
+    ):
         """Initialize a `FunctionModel`.
 
         Either `function` or `stream_function` must be provided, providing both is allowed.
@@ -64,6 +72,7 @@ class FunctionModel(Model):
         Args:
             function: The function to call for non-streamed requests.
             stream_function: The function to call for streamed requests.
+            model_name: The name of the model. If not provided, a name is generated from the function names.
         """
         if function is None and stream_function is None:
             raise TypeError('Either `function` or `stream_function` must be provided')
@@ -72,7 +81,7 @@ class FunctionModel(Model):
 
         function_name = self.function.__name__ if self.function is not None else ''
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        self._model_name = f'function:{function_name}:{stream_function_name}'
+        self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
 
     async def request(
         self,
@@ -95,7 +104,7 @@ class FunctionModel(Model):
         response_ = await _utils.run_in_executor(self.function, messages, agent_info)
         assert isinstance(response_, ModelResponse), response_
         response = response_
-        response.model_name = …
+        response.model_name = self._model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
@@ -123,7 +132,7 @@ class FunctionModel(Model):
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(_model_name=…
+        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
 
     @property
     def model_name(self) -> str:
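Example (illustrative, not part of the diff): a small sketch of the new keyword-only `model_name` override on `FunctionModel`, assuming the usual function-model contract (`messages` and `AgentInfo` in, `ModelResponse` out); the function body, model name and prompt are placeholders.

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # a canned response, typically used in tests
    return ModelResponse(parts=[TextPart(content='Paris')])


# without model_name the generated name would be 'function:model_logic:'
agent = Agent(FunctionModel(model_logic, model_name='my-test-model'))
result = agent.run_sync('What is the capital of France?')
print(result.data)
```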
pydantic_ai/models/gemini.py

@@ -15,7 +15,7 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, …
+from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -112,7 +112,7 @@ class GeminiModel(Model):
         if env_api_key := os.getenv('GEMINI_API_KEY'):
             api_key = env_api_key
         else:
-            raise …
+            raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
         self.http_client = http_client or cached_async_http_client()
         self._auth = ApiKeyAuth(api_key)
         self._url = url_template.format(model=model_name)
@@ -233,9 +233,11 @@ class GeminiModel(Model):
             headers=headers,
             timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
-            if r.status_code != 200:
+            if (status_code := r.status_code) != 200:
                 await r.aread()
-                …
+                if status_code >= 400:
+                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
+                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
             yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
@@ -318,10 +320,14 @@ class GeminiModel(Model):
                 content.append({'text': item})
             elif isinstance(item, BinaryContent):
                 base64_encoded = base64.b64encode(item.data).decode('utf-8')
-                content.append(…
+                content.append(
+                    _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
+                )
             elif isinstance(item, (AudioUrl, ImageUrl)):
                 try:
-                    content.append(…
+                    content.append(
+                        _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                    )
                 except ValueError:
                     # Download the file if can't find the mime type.
                     client = cached_async_http_client()
@@ -329,7 +335,9 @@ class GeminiModel(Model):
                     response.raise_for_status()
                     base64_encoded = base64.b64encode(response.content).decode('utf-8')
                     content.append(
-                        _GeminiInlineDataPart(…
+                        _GeminiInlineDataPart(
+                            inline_data={'data': base64_encoded, 'mime_type': response.headers['Content-Type']}
+                        )
                     )
             else:
                 assert_never(item)
@@ -526,20 +534,28 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineData(TypedDict):
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
 class _GeminiInlineDataPart(TypedDict):
     """See <https://ai.google.dev/api/caching#Blob>."""
 
-    …
-    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+    inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]
 
 
-class …
+class _GeminiFileData(TypedDict):
     """See <https://ai.google.dev/api/caching#FileData>."""
 
     file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
     mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
 
 
+class _GeminiFileDataPart(TypedDict):
+    file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
@@ -563,7 +579,7 @@ def _process_response_from_parts(
                 )
             )
         elif 'function_response' in part:
-            raise …
+            raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
@@ -615,7 +631,7 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
         Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
-        Annotated[…
+        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
     ],
     pydantic.Discriminator(_part_discriminator),
 ]
@@ -778,7 +794,7 @@ class _GeminiJsonSchema:
         # noinspection PyTypeChecker
         key = re.sub(r'^#/\$defs/', '', ref)
         if key in refs_stack:
-            raise …
+            raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
         refs_stack += (key,)
         schema_def = self.defs[key]
         self._simplify(schema_def, refs_stack)
@@ -812,7 +828,7 @@ class _GeminiJsonSchema:
     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)
         if ad_props:
-            raise …
+            raise UserError('Additional properties in JSON Schema are not supported by Gemini')
 
         if properties := schema.get('properties'):  # pragma: no branch
             for value in properties.values():
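Example (illustrative, not part of the diff): a sketch of agent-level input that exercises the new Gemini content mapping, assuming the multimodal prompt API implied by the `AudioUrl`/`BinaryContent`/`ImageUrl` exports; the model identifier, URL and bytes are placeholders. URL-referenced media is mapped to the new `file_data` part, raw bytes to the new `inline_data` part.

```python
from pydantic_ai import Agent, BinaryContent, ImageUrl

agent = Agent('google-gla:gemini-1.5-flash')  # placeholder model name

result = agent.run_sync(
    [
        'What is shown in these images?',
        ImageUrl(url='https://example.com/picture.png'),          # -> file_data part
        BinaryContent(data=b'...raw image bytes...', media_type='image/png'),  # -> inline_data part
    ]
)
print(result.data)
```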
pydantic_ai/models/groq.py

@@ -11,7 +11,7 @@ from typing import Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     BinaryContent,
@@ -39,7 +39,7 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -197,23 +197,28 @@ class GroqModel(Model):
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-… (17 removed lines not rendered in this diff view)
+        try:
+            return await self.client.chat.completions.create(
+                model=str(self._model_name),
+                messages=groq_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
pydantic_ai/models/mistral.py

@@ -13,7 +13,7 @@ import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
     BinaryContent,
@@ -59,6 +59,7 @@ try:
         ChatCompletionResponse as MistralChatCompletionResponse,
         CompletionEvent as MistralCompletionEvent,
         Messages as MistralMessages,
+        SDKError,
         Tool as MistralTool,
         ToolCall as MistralToolCall,
     )
@@ -184,19 +185,25 @@ class MistralModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-… (13 removed lines not rendered in this diff view)
+        try:
+            response = await self.client.chat.complete_async(
+                model=str(self._model_name),
+                messages=list(chain(*(self._map_message(m) for m in messages))),
+                n=1,
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
+                stream=False,
+                max_tokens=model_settings.get('max_tokens', UNSET),
+                temperature=model_settings.get('temperature', UNSET),
+                top_p=model_settings.get('top_p', 1),
+                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                random_seed=model_settings.get('seed', UNSET),
+            )
+        except SDKError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
+
 
         assert response, 'A unexpected empty response from Mistral.'
         return response
 
pydantic_ai/models/openai.py

@@ -11,7 +11,7 @@ from typing import Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     AudioUrl,
@@ -40,7 +40,7 @@ from . import (
 )
 
 try:
-    from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
+    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
     from openai.types import ChatModel, chat
     from openai.types.chat import (
         ChatCompletionChunk,
@@ -224,25 +224,30 @@ class OpenAIModel(Model):
             async for msg in self._map_message(m):
                 openai_messages.append(msg)
 
-… (19 removed lines not rendered in this diff view)
+        try:
+            return await self.client.chat.completions.create(
+                model=self._model_name,
+                messages=openai_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                stream_options={'include_usage': True} if stream else NOT_GIVEN,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+                reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
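Example (illustrative, not part of the diff): a sketch of passing the new `openai_reasoning_effort` setting through model settings, assuming the agent-level `model_settings` hook; the model name, prompt and values are placeholders.

```python
from pydantic_ai import Agent

agent = Agent('openai:o3-mini')  # placeholder model name

# keys mirror the settings read in the diff above;
# 'openai_reasoning_effort' is forwarded to the OpenAI `reasoning_effort` parameter
result = agent.run_sync(
    'Summarise the fallback behaviour in one sentence.',
    model_settings={'temperature': 0.0, 'openai_reasoning_effort': 'low'},
)
print(result.data)
```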
pyproject.toml

@@ -4,11 +4,9 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai-slim"
-version = "0.0.26"
+version = "0.0.27"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
-authors = [
-    { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
-]
+authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
 license = "MIT"
 readme = "README.md"
 classifiers = [
@@ -39,7 +37,8 @@ dependencies = [
     "httpx>=0.27",
     "logfire-api>=1.2.0",
     "pydantic>=2.10",
-    "pydantic-graph==0.0.26",
+    "pydantic-graph==0.0.27",
+    "exceptiongroup; python_version < '3.11'",
 ]
 
 [project.optional-dependencies]
All remaining files listed above are unchanged between 0.0.26 and 0.0.27.