pydantic-ai-slim 0.0.25__tar.gz → 0.0.26__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/PKG-INFO +2 -2
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/__init__.py +11 -3
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_agent_graph.py +15 -12
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/agent.py +10 -10
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/messages.py +90 -1
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/anthropic.py +37 -12
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/cohere.py +4 -1
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/function.py +12 -3
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/gemini.py +55 -3
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/groq.py +35 -8
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/mistral.py +29 -1
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/openai.py +57 -8
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pyproject.toml +2 -2
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/README.md +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/vertexai.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/usage.py +0 -0
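The headline change in this release is multimodal user input: AudioUrl, ImageUrl and BinaryContent are added to pydantic_ai.messages (and re-exported from the package root), and the user_prompt argument of Agent.run, run_sync, run_stream and iter widens from str to str | Sequence[UserContent]. A minimal usage sketch — the model name and URL below are illustrative placeholders, not taken from this diff:

    from pydantic_ai import Agent, ImageUrl

    agent = Agent('openai:gpt-4o')  # hypothetical model choice; any vision-capable model would do
    result = agent.run_sync(
        [
            'What is in this image?',
            ImageUrl(url='https://example.com/photo.png'),  # placeholder URL
        ]
    )
    print(result.data)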
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.25
+Version: 0.0.26
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,7 +28,7 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.25
+Requires-Dist: pydantic-graph==0.0.26
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/__init__.py

@@ -2,22 +2,30 @@ from importlib.metadata import version
 
 from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
+from .messages import AudioUrl, BinaryContent, ImageUrl
 from .tools import RunContext, Tool
 
 __all__ = (
+    '__version__',
+    # agent
     'Agent',
     'EndStrategy',
     'HandleResponseNode',
     'ModelRequestNode',
     'UserPromptNode',
     'capture_run_messages',
-
-    'Tool',
+    # exceptions
     'AgentRunError',
     'ModelRetry',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',
-
+    # messages
+    'ImageUrl',
+    'AudioUrl',
+    'BinaryContent',
+    # tools
+    'Tool',
+    'RunContext',
 )
 __version__ = version('pydantic_ai_slim')
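As a quick check of the widened public API above: the new message types and __version__ are importable straight from the package root (a sketch; it assumes the installed distribution is exactly this 0.0.26 release):

    from pydantic_ai import AudioUrl, BinaryContent, ImageUrl, __version__

    assert __version__ == '0.0.26'  # assumption: the installed build matches this release
    print(AudioUrl(url='talk.mp3').media_type)  # 'audio/mpeg', per the extension mapping added in messages.py below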
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/_agent_graph.py

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import asyncio
 import dataclasses
 from abc import ABC
-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
@@ -89,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
     user_deps: DepsT
 
-    prompt: str
+    prompt: str | Sequence[_messages.UserContent]
     new_message_index: int
 
     model: models.Model
@@ -108,20 +108,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
-    user_prompt: str
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+    user_prompt: str | Sequence[_messages.UserContent]
 
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
     async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> ModelRequestNode[DepsT, NodeRunEndT]:
         return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
 
     async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
         run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
@@ -135,7 +135,10 @@ class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeR
         return next_message
 
     async def _prepare_messages(
-        self,
+        self,
+        user_prompt: str | Sequence[_messages.UserContent],
+        message_history: list[_messages.ModelMessage] | None,
+        run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
             ctx_messages = get_captured_run_messages()
@@ -212,7 +215,7 @@ async def _prepare_request_parameters(
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -316,7 +319,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -338,7 +341,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
 
     @asynccontextmanager
     async def stream(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
@@ -363,7 +366,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
             handle_span.message = f'handle model response -> {tool_responses_str}'
 
     async def _run_stream(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT,
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[_messages.HandleResponseEvent]:
        if self._events_iterator is None:
            # Ensure that the stream is only run once
@@ -667,7 +670,7 @@ def get_captured_run_messages() -> _RunMessages:
 
 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT,
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/agent.py

@@ -220,7 +220,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -235,7 +235,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -249,7 +249,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -313,7 +313,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @contextmanager
     def iter(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -466,7 +466,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
@@ -480,7 +480,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -494,7 +494,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -555,7 +555,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_stream(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -570,7 +570,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_stream(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -585,7 +585,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @asynccontextmanager
     async def run_stream(  # noqa C901
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/messages.py

@@ -1,12 +1,14 @@
 from __future__ import annotations as _annotations
 
 import uuid
+from collections.abc import Sequence
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
 
 import pydantic
 import pydantic_core
+from typing_extensions import TypeAlias
 
 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -32,6 +34,93 @@ class SystemPromptPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
 
+@dataclass
+class AudioUrl:
+    """A URL to an audio file."""
+
+    url: str
+    """The URL of the audio file."""
+
+    kind: Literal['audio-url'] = 'audio-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> AudioMediaType:
+        """Return the media type of the audio file, based on the url."""
+        if self.url.endswith('.mp3'):
+            return 'audio/mpeg'
+        elif self.url.endswith('.wav'):
+            return 'audio/wav'
+        else:
+            raise ValueError(f'Unknown audio file extension: {self.url}')
+
+
+@dataclass
+class ImageUrl:
+    """A URL to an image."""
+
+    url: str
+    """The URL of the image."""
+
+    kind: Literal['image-url'] = 'image-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> ImageMediaType:
+        """Return the media type of the image, based on the url."""
+        if self.url.endswith(('.jpg', '.jpeg')):
+            return 'image/jpeg'
+        elif self.url.endswith('.png'):
+            return 'image/png'
+        elif self.url.endswith('.gif'):
+            return 'image/gif'
+        elif self.url.endswith('.webp'):
+            return 'image/webp'
+        else:
+            raise ValueError(f'Unknown image file extension: {self.url}')
+
+
+AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
+ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
+
+
+@dataclass
+class BinaryContent:
+    """Binary content, e.g. an audio or image file."""
+
+    data: bytes
+    """The binary data."""
+
+    media_type: AudioMediaType | ImageMediaType | str
+    """The media type of the binary data."""
+
+    kind: Literal['binary'] = 'binary'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def is_audio(self) -> bool:
+        """Return `True` if the media type is an audio type."""
+        return self.media_type.startswith('audio/')
+
+    @property
+    def is_image(self) -> bool:
+        """Return `True` if the media type is an image type."""
+        return self.media_type.startswith('image/')
+
+    @property
+    def audio_format(self) -> Literal['mp3', 'wav']:
+        """Return the audio format given the media type."""
+        if self.media_type == 'audio/mpeg':
+            return 'mp3'
+        elif self.media_type == 'audio/wav':
+            return 'wav'
+        else:
+            raise ValueError(f'Unknown audio media type: {self.media_type}')
+
+
+UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
+
+
 @dataclass
 class UserPromptPart:
     """A user prompt, generally written by the end user.
@@ -40,7 +129,7 @@ class UserPromptPart:
     [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
     """
 
-    content: str
+    content: str | Sequence[UserContent]
     """The content of the prompt."""
 
     timestamp: datetime = field(default_factory=_now_utc)
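A short sketch of how the new content classes behave, following the property implementations shown in this hunk (URLs and bytes are illustrative):

    from pydantic_ai.messages import AudioUrl, BinaryContent, ImageUrl

    assert ImageUrl(url='https://example.com/cat.webp').media_type == 'image/webp'  # inferred from the extension
    assert AudioUrl(url='https://example.com/talk.wav').media_type == 'audio/wav'
    # AudioUrl(url='https://example.com/talk.ogg').media_type  # would raise ValueError: unknown extension

    clip = BinaryContent(data=b'...', media_type='audio/mpeg')  # placeholder bytes
    assert clip.is_audio and not clip.is_image
    assert clip.audio_format == 'mp3'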
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/anthropic.py

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
-
+import io
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -13,6 +14,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,6 +41,7 @@ from . import (
 try:
     from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
         MetadataParam,
@@ -214,7 +218,7 @@ class AnthropicModel(Model):
         if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
             tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-        system_prompt, anthropic_messages = self._map_message(messages)
+        system_prompt, anthropic_messages = await self._map_message(messages)
 
         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
@@ -266,19 +270,19 @@ class AnthropicModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
         return tools
 
-    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
        """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
        system_prompt: str = ''
        anthropic_messages: list[MessageParam] = []
        for m in messages:
            if isinstance(m, ModelRequest):
-                user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
+                user_content_params: list[ToolResultBlockParam | TextBlockParam | ImageBlockParam] = []
                 for request_part in m.parts:
                     if isinstance(request_part, SystemPromptPart):
                         system_prompt += request_part.content
                     elif isinstance(request_part, UserPromptPart):
-
-
+                        async for content in self._map_user_prompt(request_part):
+                            user_content_params.append(content)
                     elif isinstance(request_part, ToolReturnPart):
                         tool_result_block_param = ToolResultBlockParam(
                             tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
@@ -298,12 +302,7 @@ class AnthropicModel(Model):
                             is_error=True,
                         )
                         user_content_params.append(retry_param)
-                anthropic_messages.append(
-                    MessageParam(
-                        role='user',
-                        content=user_content_params,
-                    )
-                )
+                anthropic_messages.append(MessageParam(role='user', content=user_content_params))
             elif isinstance(m, ModelResponse):
                 assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
                 for response_part in m.parts:
@@ -322,6 +321,32 @@ class AnthropicModel(Model):
             assert_never(m)
         return system_prompt, anthropic_messages
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
+        if isinstance(part.content, str):
+            yield TextBlockParam(text=part.content, type='text')
+        else:
+            for item in part.content:
+                if isinstance(item, str):
+                    yield TextBlockParam(text=item, type='text')
+                elif isinstance(item, BinaryContent):
+                    if item.is_image:
+                        yield ImageBlockParam(
+                            source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
+                            type='image',
+                        )
+                    else:
+                        raise RuntimeError('Only images are supported for binary content')
+                elif isinstance(item, ImageUrl):
+                    response = await cached_async_http_client().get(item.url)
+                    response.raise_for_status()
+                    yield ImageBlockParam(
+                        source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
+                        type='image',
+                    )
+                else:
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
         return {
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/cohere.py

@@ -242,7 +242,10 @@ class CohereModel(Model):
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
-
+                if isinstance(part.content, str):
+                    yield UserChatMessageV2(role='user', content=part.content)
+                else:
+                    raise RuntimeError('Cohere does not yet support multi-modal inputs.')
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/function.py

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,6 +14,9 @@ from typing_extensions import TypeAlias, assert_never, overload
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -23,6 +26,7 @@ from ..messages import (
     TextPart,
     ToolCallPart,
     ToolReturnPart,
+    UserContent,
     UserPromptPart,
 )
 from ..settings import ModelSettings
@@ -262,7 +266,12 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     )
 
 
-def _estimate_string_tokens(content: str) -> int:
+def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
-
+    if isinstance(content, str):
+        return len(re.split(r'[\s",.:]+', content.strip()))
+    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+    else:  # pragma: no cover
+        assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
+        return 0
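For the string branch above, the token estimate is simply the number of fragments left after splitting on whitespace and basic punctuation, while URL and binary parts are counted as zero for now. A worked example with an illustrative input:

    import re

    content = 'Hello, world: this is a test.'
    estimate = len(re.split(r'[\s",.:]+', content.strip()))  # same expression as _estimate_string_tokens
    # re.split returns ['Hello', 'world', 'this', 'is', 'a', 'test', ''] -> estimate == 7 (the trailing '' is counted)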
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/gemini.py

@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 import re
 from collections.abc import AsyncIterator, Sequence
@@ -16,6 +17,9 @@ from typing_extensions import NotRequired, TypedDict, assert_never
 
 from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -185,7 +189,7 @@ class GeminiModel(Model):
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
         tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -269,7 +273,7 @@ class GeminiModel(Model):
         return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
-    def _message_to_gemini_content(
+    async def _message_to_gemini_content(
         cls, messages: list[ModelMessage]
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
@@ -282,7 +286,7 @@ class GeminiModel(Model):
                 if isinstance(part, SystemPromptPart):
                     sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                 elif isinstance(part, UserPromptPart):
-                    message_parts.
+                    message_parts.extend(await cls._map_user_prompt(part))
                 elif isinstance(part, ToolReturnPart):
                     message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                 elif isinstance(part, RetryPromptPart):
@@ -303,6 +307,34 @@ class GeminiModel(Model):
 
         return sys_prompt_parts, contents
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
+        if isinstance(part.content, str):
+            return [{'text': part.content}]
+        else:
+            content: list[_GeminiPartUnion] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append({'text': item})
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type))
+                elif isinstance(item, (AudioUrl, ImageUrl)):
+                    try:
+                        content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type))
+                    except ValueError:
+                        # Download the file if can't find the mime type.
+                        client = cached_async_http_client()
+                        response = await client.get(item.url, follow_redirects=True)
+                        response.raise_for_status()
+                        base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                        content.append(
+                            _GeminiInlineDataPart(data=base64_encoded, mime_type=response.headers['Content-Type'])
+                        )
+                else:
+                    assert_never(item)
+            return content
+
 
 class AuthProtocol(Protocol):
     """Abstract definition for Gemini authentication."""
@@ -494,6 +526,20 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineDataPart(TypedDict):
+    """See <https://ai.google.dev/api/caching#Blob>."""
+
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiFileDataData(TypedDict):
+    """See <https://ai.google.dev/api/caching#FileData>."""
+
+    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -549,6 +595,10 @@ def _part_discriminator(v: Any) -> str:
     if isinstance(v, dict):
         if 'text' in v:
             return 'text'
+        elif 'inlineData' in v:
+            return 'inline_data'
+        elif 'fileData' in v:
+            return 'file_data'
         elif 'functionCall' in v or 'function_call' in v:
             return 'function_call'
         elif 'functionResponse' in v or 'function_response' in v:
@@ -564,6 +614,8 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiTextPart, pydantic.Tag('text')],
         Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
+        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
+        Annotated[_GeminiFileDataData, pydantic.Tag('file_data')],
     ],
     pydantic.Discriminator(_part_discriminator),
 ]
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/groq.py

@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -13,6 +14,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,7 +41,7 @@ from . import (
 try:
     from groq import NOT_GIVEN, AsyncGroq, AsyncStream
     from groq.types import chat
-    from groq.types.chat import
+    from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -163,7 +166,7 @@ class GroqModel(Model):
         stream: Literal[True],
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> AsyncStream[ChatCompletionChunk]:
+    ) -> AsyncStream[chat.ChatCompletionChunk]:
         pass
 
     @overload
@@ -182,7 +185,7 @@ class GroqModel(Model):
         stream: bool,
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
         if not tools:
@@ -224,7 +227,7 @@ class GroqModel(Model):
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
-    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -293,7 +296,7 @@ class GroqModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield
+                yield cls._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
@@ -310,13 +313,37 @@ class GroqModel(Model):
                     content=part.model_response(),
                 )
 
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        content: str | list[chat.ChatCompletionContentPartParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
+                elif isinstance(item, ImageUrl):
+                    image_url = ImageURL(url=item.url)
+                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only images are supported for binary content in Groq.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return chat.ChatCompletionUserMessageParam(role='user', content=content)
+
 
 @dataclass
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""
 
     _model_name: GroqModelName
-    _response: AsyncIterable[ChatCompletionChunk]
+    _response: AsyncIterable[chat.ChatCompletionChunk]
     _timestamp: datetime
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -355,9 +382,9 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
     response_usage = None
-    if isinstance(completion, ChatCompletion):
+    if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
     elif completion.x_groq is not None:
         response_usage = completion.x_groq.usage
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/mistral.py

@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
@@ -15,6 +16,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -45,6 +48,8 @@ try:
         Content as MistralContent,
         ContentChunk as MistralContentChunk,
         FunctionCall as MistralFunctionCall,
+        ImageURL as MistralImageURL,
+        ImageURLChunk as MistralImageURLChunk,
         Mistral,
         OptionalNullable as MistralOptionalNullable,
         TextChunk as MistralTextChunk,
@@ -423,7 +428,7 @@ class MistralModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield MistralSystemMessage(content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield
+                yield cls._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield MistralToolMessage(
                     tool_call_id=part.tool_call_id,
@@ -460,6 +465,29 @@ class MistralModel(Model):
         else:
             assert_never(message)
 
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+        content: str | list[MistralContentChunk]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(MistralTextChunk(text=item))
+                elif isinstance(item, ImageUrl):
+                    content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url)))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only image binary content is supported for Mistral.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return MistralUserMessage(content=content)
+
 
 MistralToolCallId = Union[str, None]
 
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/models/openai.py

@@ -1,11 +1,11 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
-from collections.abc import AsyncIterable, AsyncIterator
+from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from itertools import chain
 from typing import Literal, Union, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
@@ -14,6 +14,9 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -39,7 +42,15 @@ from . import (
 try:
     from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
     from openai.types import ChatModel, chat
-    from openai.types.chat import
+    from openai.types.chat import (
+        ChatCompletionChunk,
+        ChatCompletionContentPartImageParam,
+        ChatCompletionContentPartInputAudioParam,
+        ChatCompletionContentPartParam,
+        ChatCompletionContentPartTextParam,
+    )
+    from openai.types.chat.chat_completion_content_part_image_param import ImageURL
+    from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -208,7 +219,10 @@ class OpenAIModel(Model):
         else:
             tool_choice = 'auto'
 
-        openai_messages
+        openai_messages: list[chat.ChatCompletionMessageParam] = []
+        for m in messages:
+            async for msg in self._map_message(m):
+                openai_messages.append(msg)
 
         return await self.client.chat.completions.create(
             model=self._model_name,
@@ -261,10 +275,11 @@ class OpenAIModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
         return tools
 
-    def _map_message(self, message: ModelMessage) ->
+    async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-
+            async for item in self._map_user_message(message):
+                yield item
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -305,7 +320,7 @@ class OpenAIModel(Model):
             },
         }
 
-    def _map_user_message(self, message: ModelRequest) ->
+    async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
                 if self.system_prompt_role == 'developer':
@@ -315,7 +330,7 @@ class OpenAIModel(Model):
                 else:
                     yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield
+                yield await self._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
@@ -334,6 +349,40 @@ class OpenAIModel(Model):
             else:
                 assert_never(part)
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        content: str | list[ChatCompletionContentPartParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
+                elif isinstance(item, ImageUrl):
+                    image_url = ImageURL(url=item.url)
+                    content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                    elif item.is_audio:
+                        audio = InputAudio(data=base64_encoded, format=item.audio_format)
+                        content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
+                    else:  # pragma: no cover
+                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                elif isinstance(item, AudioUrl):  # pragma: no cover
+                    client = cached_async_http_client()
+                    response = await client.get(item.url)
+                    response.raise_for_status()
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
+                    content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
+                else:
+                    assert_never(item)
+        return chat.ChatCompletionUserMessageParam(role='user', content=content)
+
 
 @dataclass
 class OpenAIStreamedResponse(StreamedResponse):
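Going by _map_user_prompt above, an image supplied as BinaryContent is sent to OpenAI as an image_url content part carrying a base64 data URL, and audio becomes an input_audio part. A sketch of the image case with placeholder bytes:

    import base64

    from pydantic_ai.messages import BinaryContent

    part = BinaryContent(data=b'\x89PNG\r\n\x1a\n', media_type='image/png')  # placeholder image bytes
    data_url = f'data:{part.media_type};base64,{base64.b64encode(part.data).decode("utf-8")}'
    # _map_user_prompt wraps this as {'type': 'image_url', 'image_url': {'url': data_url}}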
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pydantic_ai/tools.py

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import dataclasses
 import inspect
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Sequence
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
@@ -45,7 +45,7 @@ class RunContext(Generic[AgentDepsT]):
     """The model used in this run."""
     usage: Usage
     """LLM usage associated with the run."""
-    prompt: str
+    prompt: str | Sequence[_messages.UserContent]
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
{pydantic_ai_slim-0.0.25 → pydantic_ai_slim-0.0.26}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai-slim"
-version = "0.0.25"
+version = "0.0.26"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -39,7 +39,7 @@ dependencies = [
     "httpx>=0.27",
     "logfire-api>=1.2.0",
     "pydantic>=2.10",
-    "pydantic-graph==0.0.25",
+    "pydantic-graph==0.0.26",
 ]
 
 [project.optional-dependencies]