pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +16 -4
- pydantic_ai/_agent_graph.py +264 -351
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +581 -156
- pydantic_ai/messages.py +121 -1
- pydantic_ai/models/__init__.py +12 -1
- pydantic_ai/models/anthropic.py +67 -50
- pydantic_ai/models/cohere.py +5 -2
- pydantic_ai/models/function.py +15 -6
- pydantic_ai/models/gemini.py +73 -5
- pydantic_ai/models/groq.py +35 -8
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +29 -4
- pydantic_ai/models/openai.py +59 -13
- pydantic_ai/models/test.py +6 -6
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +106 -144
- pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.26.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.26.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.24.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.26.dist-info}/WHEEL +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations

+import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -13,6 +14,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,7 +41,7 @@ from . import (
 try:
     from groq import NOT_GIVEN, AsyncGroq, AsyncStream
     from groq.types import chat
-    from groq.types.chat import ChatCompletion, ChatCompletionChunk
+    from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -163,7 +166,7 @@ class GroqModel(Model):
         stream: Literal[True],
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> AsyncStream[ChatCompletionChunk]:
+    ) -> AsyncStream[chat.ChatCompletionChunk]:
         pass

     @overload
@@ -182,7 +185,7 @@ class GroqModel(Model):
         stream: bool,
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
         if not tools:
@@ -224,7 +227,7 @@ class GroqModel(Model):
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=response.model, timestamp=timestamp)

-    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -293,7 +296,7 @@ class GroqModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                yield cls._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
@@ -310,13 +313,37 @@ class GroqModel(Model):
                     content=part.model_response(),
                 )

+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        content: str | list[chat.ChatCompletionContentPartParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
+                elif isinstance(item, ImageUrl):
+                    image_url = ImageURL(url=item.url)
+                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only images are supported for binary content in Groq.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return chat.ChatCompletionUserMessageParam(role='user', content=content)
+

 @dataclass
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""

     _model_name: GroqModelName
-    _response: AsyncIterable[ChatCompletionChunk]
+    _response: AsyncIterable[chat.ChatCompletionChunk]
     _timestamp: datetime

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -355,9 +382,9 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
     response_usage = None
-    if isinstance(completion, ChatCompletion):
+    if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
     elif completion.x_groq is not None:
         response_usage = completion.x_groq.usage
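Note: the new `_map_user_prompt` above enables image input for Groq. A minimal usage sketch (not part of this diff) follows; the model name and image URL are placeholders, and it assumes the agent run API in this release accepts a sequence of content parts as the user prompt.

from pydantic_ai import Agent
from pydantic_ai.messages import ImageUrl

# Placeholder Groq vision-capable model name; any Groq model that accepts image input would do.
agent = Agent('groq:llama-3.2-11b-vision-preview')

# The list mixes plain text with an ImageUrl part, which _map_user_prompt
# converts into Groq chat content parts.
result = agent.run_sync(
    [
        'What is shown in this image?',
        ImageUrl(url='https://example.com/photo.jpg'),
    ]
)
print(result.data)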
pydantic_ai/models/instrumented.py
ADDED
@@ -0,0 +1,225 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any, Callable, Literal
+
+import logfire_api
+from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelRequestPart,
+    ModelResponse,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..usage import Usage
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
+from .wrapper import WrapperModel
+
+MODEL_SETTING_ATTRIBUTES: tuple[
+    Literal[
+        'max_tokens',
+        'top_p',
+        'seed',
+        'temperature',
+        'presence_penalty',
+        'frequency_penalty',
+    ],
+    ...,
+] = (
+    'max_tokens',
+    'top_p',
+    'seed',
+    'temperature',
+    'presence_penalty',
+    'frequency_penalty',
+)
+
+NOT_GIVEN = object()
+
+
+@dataclass
+class InstrumentedModel(WrapperModel):
+    """Model which is instrumented with logfire."""
+
+    tracer: Tracer = field(repr=False)
+    event_logger: EventLogger = field(repr=False)
+
+    def __init__(
+        self,
+        wrapped: Model | KnownModelName,
+        tracer_provider: TracerProvider | None = None,
+        event_logger_provider: EventLoggerProvider | None = None,
+    ):
+        super().__init__(wrapped)
+        tracer_provider = tracer_provider or get_tracer_provider()
+        event_logger_provider = event_logger_provider or get_event_logger_provider()
+        self.tracer = tracer_provider.get_tracer('pydantic-ai')
+        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+    @classmethod
+    def from_logfire(
+        cls,
+        wrapped: Model | KnownModelName,
+        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+    ) -> InstrumentedModel:
+        if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+            event_provider = logfire_instance.config.get_event_logger_provider()
+        else:
+            event_provider = None
+        tracer_provider = logfire_instance.config.get_tracer_provider()
+        return cls(wrapped, tracer_provider, event_provider)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        with self._instrument(messages, model_settings) as finish:
+            response, usage = await super().request(messages, model_settings, model_request_parameters)
+            finish(response, usage)
+            return response, usage
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        with self._instrument(messages, model_settings) as finish:
+            response_stream: StreamedResponse | None = None
+            try:
+                async with super().request_stream(
+                    messages, model_settings, model_request_parameters
+                ) as response_stream:
+                    yield response_stream
+            finally:
+                if response_stream:
+                    finish(response_stream.get(), response_stream.usage())
+
+    @contextmanager
+    def _instrument(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+        operation = 'chat'
+        model_name = self.model_name
+        span_name = f'{operation} {model_name}'
+        system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
+        system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
+        # TODO Missing attributes:
+        # - server.address: requires a Model.base_url abstract method or similar
+        # - server.port: to parse from the base_url
+        # - error.type: unclear if we should do something here or just always rely on span exceptions
+        # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
+        attributes: dict[str, Any] = {
+            'gen_ai.operation.name': operation,
+            'gen_ai.system': system,
+            'gen_ai.request.model': model_name,
+        }
+
+        if model_settings:
+            for key in MODEL_SETTING_ATTRIBUTES:
+                if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                    attributes[f'gen_ai.request.{key}'] = value
+
+        emit_event = partial(self._emit_event, system)
+
+        with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+            if span.is_recording():
+                for message in messages:
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            event_name, body = _request_part_body(part)
+                            if event_name:
+                                emit_event(event_name, body)
+                    elif isinstance(message, ModelResponse):
+                        for body in _response_bodies(message):
+                            emit_event('gen_ai.assistant.message', body)
+
+            def finish(response: ModelResponse, usage: Usage):
+                if not span.is_recording():
+                    return
+
+                for response_body in _response_bodies(response):
+                    if response_body:
+                        emit_event(
+                            'gen_ai.choice',
+                            {
+                                # TODO finish_reason
+                                'index': 0,
+                                'message': response_body,
+                            },
+                        )
+                span.set_attributes(
+                    {
+                        k: v
+                        for k, v in {
+                            # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                            # https://github.com/pydantic/pydantic-ai/issues/886
+                            'gen_ai.response.model': response.model_name or model_name,
+                            'gen_ai.usage.input_tokens': usage.request_tokens,
+                            'gen_ai.usage.output_tokens': usage.response_tokens,
+                        }.items()
+                        if v is not None
+                    }
+                )

+            yield finish
+
+    def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
+        self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
+
+
+def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
+    if isinstance(part, SystemPromptPart):
+        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
+    elif isinstance(part, UserPromptPart):
+        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
+    elif isinstance(part, ToolReturnPart):
+        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
+    elif isinstance(part, RetryPromptPart):
+        if part.tool_name is None:
+            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+        else:
+            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
+    else:
+        return '', {}
+
+
+def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
+    body: dict[str, Any] = {'role': 'assistant'}
+    result = [body]
+    for part in message.parts:
+        if isinstance(part, ToolCallPart):
+            body.setdefault('tool_calls', []).append(
+                {
+                    'id': part.tool_call_id,
+                    'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                    'function': {
+                        'name': part.tool_name,
+                        'arguments': part.args,
+                    },
+                }
+            )
+        elif isinstance(part, TextPart):
+            if body.get('content'):
+                body = {'role': 'assistant'}
+                result.append(body)
+            body['content'] = part.content
+
+    return result
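Note: a minimal sketch of how the new `InstrumentedModel` might be used, based only on the constructor and `from_logfire` classmethod shown above; the wrapped model name is a placeholder and the agent wiring is an assumption rather than part of this diff.

from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentedModel

# Wrap any model (or known model name) to emit OpenTelemetry spans and events.
# With no providers passed, the global OTel tracer/event-logger providers are used.
instrumented = InstrumentedModel('openai:gpt-4o')

# Alternatively, if logfire is configured:
#   import logfire
#   logfire.configure()
#   instrumented = InstrumentedModel.from_logfire('openai:gpt-4o')

agent = Agent(instrumented)
result = agent.run_sync('Hello!')
print(result.data)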
pydantic_ai/models/mistral.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations

+import base64
 import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
@@ -15,6 +16,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -45,6 +48,8 @@ try:
         Content as MistralContent,
         ContentChunk as MistralContentChunk,
         FunctionCall as MistralFunctionCall,
+        ImageURL as MistralImageURL,
+        ImageURLChunk as MistralImageURLChunk,
         Mistral,
         OptionalNullable as MistralOptionalNullable,
         TextChunk as MistralTextChunk,
@@ -134,9 +139,6 @@ class MistralModel(Model):
         api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
         self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

-    def name(self) -> str:
-        return f'mistral:{self._model_name}'
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -426,7 +428,7 @@ class MistralModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield MistralSystemMessage(content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield MistralUserMessage(content=part.content)
+                yield cls._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield MistralToolMessage(
                     tool_call_id=part.tool_call_id,
@@ -463,6 +465,29 @@ class MistralModel(Model):
             else:
                 assert_never(message)

+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+        content: str | list[MistralContentChunk]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(MistralTextChunk(text=item))
+                elif isinstance(item, ImageUrl):
+                    content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url)))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only image binary content is supported for Mistral.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return MistralUserMessage(content=content)
+

 MistralToolCallId = Union[str, None]

pydantic_ai/models/openai.py
CHANGED
@@ -1,11 +1,11 @@
 from __future__ import annotations as _annotations

+import base64
 import os
-from collections.abc import AsyncIterable, AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from itertools import chain
 from typing import Literal, Union, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
@@ -14,6 +14,9 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -39,7 +42,15 @@ from . import (
 try:
     from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
     from openai.types import ChatModel, chat
-    from openai.types.chat import ChatCompletionChunk
+    from openai.types.chat import (
+        ChatCompletionChunk,
+        ChatCompletionContentPartImageParam,
+        ChatCompletionContentPartInputAudioParam,
+        ChatCompletionContentPartParam,
+        ChatCompletionContentPartTextParam,
+    )
+    from openai.types.chat.chat_completion_content_part_image_param import ImageURL
+    from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -119,9 +130,9 @@ class OpenAIModel(Model):
         """
         self._model_name = model_name
         # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
-        # openai compatible models do not always need an API key.
+        # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
         if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
-            api_key = ''
+            api_key = 'api-key-not-set'

         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
@@ -135,9 +146,6 @@ class OpenAIModel(Model):
         self.system_prompt_role = system_prompt_role
         self._system = system

-    def name(self) -> str:
-        return f'openai:{self._model_name}'
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -211,7 +219,10 @@ class OpenAIModel(Model):
         else:
             tool_choice = 'auto'

-        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+        openai_messages: list[chat.ChatCompletionMessageParam] = []
+        for m in messages:
+            async for msg in self._map_message(m):
+                openai_messages.append(msg)

         return await self.client.chat.completions.create(
             model=self._model_name,
@@ -264,10 +275,11 @@ class OpenAIModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
         return tools

-    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from self._map_user_message(message)
+            async for item in self._map_user_message(message):
+                yield item
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -308,7 +320,7 @@ class OpenAIModel(Model):
             },
         }

-    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+    async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
                 if self.system_prompt_role == 'developer':
@@ -318,7 +330,7 @@ class OpenAIModel(Model):
                 else:
                     yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                yield await self._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
@@ -337,6 +349,40 @@ class OpenAIModel(Model):
             else:
                 assert_never(part)

+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        content: str | list[ChatCompletionContentPartParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
+                elif isinstance(item, ImageUrl):
+                    image_url = ImageURL(url=item.url)
+                    content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                    elif item.is_audio:
+                        audio = InputAudio(data=base64_encoded, format=item.audio_format)
+                        content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
+                    else:  # pragma: no cover
+                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                elif isinstance(item, AudioUrl):  # pragma: no cover
+                    client = cached_async_http_client()
+                    response = await client.get(item.url)
+                    response.raise_for_status()
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
+                    content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
+                else:
+                    assert_never(item)
+        return chat.ChatCompletionUserMessageParam(role='user', content=content)
+

 @dataclass
 class OpenAIStreamedResponse(StreamedResponse):
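Note: a sketch of driving the new OpenAI audio input path via `BinaryContent`; the file path and model name are placeholders, and the sequence-style user prompt is assumed from the messages changes in this release rather than shown in this diff.

from pathlib import Path

from pydantic_ai import Agent
from pydantic_ai.messages import BinaryContent

# BinaryContent with an audio media type is mapped to an `input_audio` content part
# by the new _map_user_prompt above.
audio = BinaryContent(data=Path('recording.mp3').read_bytes(), media_type='audio/mpeg')

agent = Agent('openai:gpt-4o-audio-preview')
result = agent.run_sync(['Transcribe this recording.', audio])
print(result.data)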
pydantic_ai/models/test.py
CHANGED
@@ -130,15 +130,15 @@ class TestModel(Model):

     def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
         if self.custom_result_text is not None:
-            assert (
-                model_request_parameters.allow_text_result
-            ), 'Plain response not allowed, but `custom_result_text` is set.'
+            assert model_request_parameters.allow_text_result, (
+                'Plain response not allowed, but `custom_result_text` is set.'
+            )
             assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
             return _TextResult(self.custom_result_text)
         elif self.custom_result_args is not None:
-            assert (
-                model_request_parameters.result_tools is not None
-            ), 'No result tools provided, but `custom_result_args` is set.'
+            assert model_request_parameters.result_tools is not None, (
+                'No result tools provided, but `custom_result_args` is set.'
+            )
             result_tool = model_request_parameters.result_tools[0]

             if k := result_tool.outer_typed_dict_key:
pydantic_ai/models/wrapper.py
ADDED
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any
+
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
+from ..usage import Usage
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+
+@dataclass(init=False)
+class WrapperModel(Model):
+    """Model which wraps another model."""
+
+    wrapped: Model
+
+    def __init__(self, wrapped: Model | KnownModelName):
+        self.wrapped = infer_model(wrapped)
+
+    async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+        return await self.wrapped.request(*args, **kwargs)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
+            yield response_stream
+
+    @property
+    def model_name(self) -> str:
+        return self.wrapped.model_name
+
+    @property
+    def system(self) -> str | None:
+        return self.wrapped.system
+
+    def __getattr__(self, item: str):
+        return getattr(self.wrapped, item)