pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -1,5 +1,6 @@
  from __future__ import annotations as _annotations
 
+ import base64
  from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
@@ -13,6 +14,8 @@ from typing_extensions import assert_never
  from .. import UnexpectedModelBehavior, _utils, usage
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
+     BinaryContent,
+     ImageUrl,
      ModelMessage,
      ModelRequest,
      ModelResponse,
@@ -38,7 +41,7 @@ from . import (
  try:
      from groq import NOT_GIVEN, AsyncGroq, AsyncStream
      from groq.types import chat
-     from groq.types.chat import ChatCompletion, ChatCompletionChunk
+     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
  except ImportError as _import_error:
      raise ImportError(
          'Please install `groq` to use the Groq model, '
@@ -163,7 +166,7 @@ class GroqModel(Model):
          stream: Literal[True],
          model_settings: GroqModelSettings,
          model_request_parameters: ModelRequestParameters,
-     ) -> AsyncStream[ChatCompletionChunk]:
+     ) -> AsyncStream[chat.ChatCompletionChunk]:
          pass
 
      @overload
@@ -182,7 +185,7 @@ class GroqModel(Model):
          stream: bool,
          model_settings: GroqModelSettings,
          model_request_parameters: ModelRequestParameters,
-     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+     ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
          tools = self._get_tools(model_request_parameters)
          # standalone function to make it easier to override
          if not tools:
@@ -224,7 +227,7 @@ class GroqModel(Model):
                  items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
          return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
-     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+     async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
          """Process a streamed response, and prepare a streaming response to return."""
          peekable_response = _utils.PeekableAsyncStream(response)
          first_chunk = await peekable_response.peek()
@@ -293,7 +296,7 @@ class GroqModel(Model):
              if isinstance(part, SystemPromptPart):
                  yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
              elif isinstance(part, UserPromptPart):
-                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                 yield cls._map_user_prompt(part)
              elif isinstance(part, ToolReturnPart):
                  yield chat.ChatCompletionToolMessageParam(
                      role='tool',
@@ -310,13 +313,37 @@ class GroqModel(Model):
                      content=part.model_response(),
                  )
 
+     @staticmethod
+     def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+         content: str | list[chat.ChatCompletionContentPartParam]
+         if isinstance(part.content, str):
+             content = part.content
+         else:
+             content = []
+             for item in part.content:
+                 if isinstance(item, str):
+                     content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
+                 elif isinstance(item, ImageUrl):
+                     image_url = ImageURL(url=item.url)
+                     content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                 elif isinstance(item, BinaryContent):
+                     base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                     if item.is_image:
+                         image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                         content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                     else:
+                         raise RuntimeError('Only images are supported for binary content in Groq.')
+                 else:  # pragma: no cover
+                     raise RuntimeError(f'Unsupported content type: {type(item)}')
+         return chat.ChatCompletionUserMessageParam(role='user', content=content)
+
 
  @dataclass
  class GroqStreamedResponse(StreamedResponse):
      """Implementation of `StreamedResponse` for Groq models."""
 
      _model_name: GroqModelName
-     _response: AsyncIterable[ChatCompletionChunk]
+     _response: AsyncIterable[chat.ChatCompletionChunk]
      _timestamp: datetime
 
      async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -355,9 +382,9 @@ class GroqStreamedResponse(StreamedResponse):
          return self._timestamp
 
 
- def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
      response_usage = None
-     if isinstance(completion, ChatCompletion):
+     if isinstance(completion, chat.ChatCompletion):
          response_usage = completion.usage
      elif completion.x_groq is not None:
          response_usage = completion.x_groq.usage
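
The new `GroqModel._map_user_prompt` above lets Groq requests carry image parts as well as text. A minimal usage sketch, not part of this diff: it assumes the agent accepts a sequence of prompt parts (mirroring `UserPromptPart.content` here), and the vision-capable model name is illustrative.

    from pydantic_ai import Agent
    from pydantic_ai.messages import BinaryContent, ImageUrl

    agent = Agent('groq:llama-3.2-11b-vision-preview')  # illustrative model name

    # A remote image passes through as an `image_url` content part.
    result = agent.run_sync(['What is shown in this image?', ImageUrl(url='https://example.com/photo.jpg')])
    print(result.data)

    # Raw bytes are base64-encoded into a `data:` URL; non-image binary content raises RuntimeError.
    with open('photo.jpg', 'rb') as f:
        result = agent.run_sync(['Describe this image.', BinaryContent(data=f.read(), media_type='image/jpeg')])
    print(result.data)
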
@@ -0,0 +1,225 @@
+ from __future__ import annotations
+
+ from collections.abc import AsyncIterator, Iterator
+ from contextlib import asynccontextmanager, contextmanager
+ from dataclasses import dataclass, field
+ from functools import partial
+ from typing import Any, Callable, Literal
+
+ import logfire_api
+ from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+ from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+
+ from ..messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelRequestPart,
+     ModelResponse,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+ )
+ from ..settings import ModelSettings
+ from ..usage import Usage
+ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
+ from .wrapper import WrapperModel
+
+ MODEL_SETTING_ATTRIBUTES: tuple[
+     Literal[
+         'max_tokens',
+         'top_p',
+         'seed',
+         'temperature',
+         'presence_penalty',
+         'frequency_penalty',
+     ],
+     ...,
+ ] = (
+     'max_tokens',
+     'top_p',
+     'seed',
+     'temperature',
+     'presence_penalty',
+     'frequency_penalty',
+ )
+
+ NOT_GIVEN = object()
+
+
+ @dataclass
+ class InstrumentedModel(WrapperModel):
+     """Model which is instrumented with logfire."""
+
+     tracer: Tracer = field(repr=False)
+     event_logger: EventLogger = field(repr=False)
+
+     def __init__(
+         self,
+         wrapped: Model | KnownModelName,
+         tracer_provider: TracerProvider | None = None,
+         event_logger_provider: EventLoggerProvider | None = None,
+     ):
+         super().__init__(wrapped)
+         tracer_provider = tracer_provider or get_tracer_provider()
+         event_logger_provider = event_logger_provider or get_event_logger_provider()
+         self.tracer = tracer_provider.get_tracer('pydantic-ai')
+         self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+     @classmethod
+     def from_logfire(
+         cls,
+         wrapped: Model | KnownModelName,
+         logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+     ) -> InstrumentedModel:
+         if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+             event_provider = logfire_instance.config.get_event_logger_provider()
+         else:
+             event_provider = None
+         tracer_provider = logfire_instance.config.get_tracer_provider()
+         return cls(wrapped, tracer_provider, event_provider)
+
+     async def request(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> tuple[ModelResponse, Usage]:
+         with self._instrument(messages, model_settings) as finish:
+             response, usage = await super().request(messages, model_settings, model_request_parameters)
+             finish(response, usage)
+             return response, usage
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         with self._instrument(messages, model_settings) as finish:
+             response_stream: StreamedResponse | None = None
+             try:
+                 async with super().request_stream(
+                     messages, model_settings, model_request_parameters
+                 ) as response_stream:
+                     yield response_stream
+             finally:
+                 if response_stream:
+                     finish(response_stream.get(), response_stream.usage())
+
+     @contextmanager
+     def _instrument(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+     ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+         operation = 'chat'
+         model_name = self.model_name
+         span_name = f'{operation} {model_name}'
+         system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
+         system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
+         # TODO Missing attributes:
+         # - server.address: requires a Model.base_url abstract method or similar
+         # - server.port: to parse from the base_url
+         # - error.type: unclear if we should do something here or just always rely on span exceptions
+         # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
+         attributes: dict[str, Any] = {
+             'gen_ai.operation.name': operation,
+             'gen_ai.system': system,
+             'gen_ai.request.model': model_name,
+         }
+
+         if model_settings:
+             for key in MODEL_SETTING_ATTRIBUTES:
+                 if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                     attributes[f'gen_ai.request.{key}'] = value
+
+         emit_event = partial(self._emit_event, system)
+
+         with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+             if span.is_recording():
+                 for message in messages:
+                     if isinstance(message, ModelRequest):
+                         for part in message.parts:
+                             event_name, body = _request_part_body(part)
+                             if event_name:
+                                 emit_event(event_name, body)
+                     elif isinstance(message, ModelResponse):
+                         for body in _response_bodies(message):
+                             emit_event('gen_ai.assistant.message', body)
+
+             def finish(response: ModelResponse, usage: Usage):
+                 if not span.is_recording():
+                     return
+
+                 for response_body in _response_bodies(response):
+                     if response_body:
+                         emit_event(
+                             'gen_ai.choice',
+                             {
+                                 # TODO finish_reason
+                                 'index': 0,
+                                 'message': response_body,
+                             },
+                         )
+                 span.set_attributes(
+                     {
+                         k: v
+                         for k, v in {
+                             # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                             # https://github.com/pydantic/pydantic-ai/issues/886
+                             'gen_ai.response.model': response.model_name or model_name,
+                             'gen_ai.usage.input_tokens': usage.request_tokens,
+                             'gen_ai.usage.output_tokens': usage.response_tokens,
+                         }.items()
+                         if v is not None
+                     }
+                 )
+
+             yield finish
+
+     def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
+         self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
+
+
+ def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
+     if isinstance(part, SystemPromptPart):
+         return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
+     elif isinstance(part, UserPromptPart):
+         return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
+     elif isinstance(part, ToolReturnPart):
+         return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
+     elif isinstance(part, RetryPromptPart):
+         if part.tool_name is None:
+             return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+         else:
+             return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
+     else:
+         return '', {}
+
+
+ def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
+     body: dict[str, Any] = {'role': 'assistant'}
+     result = [body]
+     for part in message.parts:
+         if isinstance(part, ToolCallPart):
+             body.setdefault('tool_calls', []).append(
+                 {
+                     'id': part.tool_call_id,
+                     'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                     'function': {
+                         'name': part.tool_name,
+                         'arguments': part.args,
+                     },
+                 }
+             )
+         elif isinstance(part, TextPart):
+             if body.get('content'):
+                 body = {'role': 'assistant'}
+                 result.append(body)
+             body['content'] = part.content
+
+     return result
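
The new `InstrumentedModel` above wraps any other model, opens a `chat {model}` span per request, and emits `gen_ai.*` events for the exchanged messages. A minimal usage sketch, not part of this diff: the module path `pydantic_ai.models.instrumented` and the wrapped model name are assumptions.

    import logfire
    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentedModel

    logfire.configure()  # set up logfire so its tracer/event-logger providers are used

    # Wrap a model (instance or known-model name) so every request is traced.
    model = InstrumentedModel.from_logfire('openai:gpt-4o')  # model name illustrative
    agent = Agent(model)
    result = agent.run_sync('Hello!')
    print(result.data)
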
@@ -1,5 +1,6 @@
  from __future__ import annotations as _annotations
 
+ import base64
  import os
  from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
@@ -15,6 +16,8 @@ from typing_extensions import assert_never
  from .. import UnexpectedModelBehavior, _utils
  from .._utils import now_utc as _now_utc
  from ..messages import (
+     BinaryContent,
+     ImageUrl,
      ModelMessage,
      ModelRequest,
      ModelResponse,
@@ -45,6 +48,8 @@ try:
          Content as MistralContent,
          ContentChunk as MistralContentChunk,
          FunctionCall as MistralFunctionCall,
+         ImageURL as MistralImageURL,
+         ImageURLChunk as MistralImageURLChunk,
          Mistral,
          OptionalNullable as MistralOptionalNullable,
          TextChunk as MistralTextChunk,
@@ -134,9 +139,6 @@ class MistralModel(Model):
          api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
          self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
 
-     def name(self) -> str:
-         return f'mistral:{self._model_name}'
-
      async def request(
          self,
          messages: list[ModelMessage],
@@ -426,7 +428,7 @@ class MistralModel(Model):
              if isinstance(part, SystemPromptPart):
                  yield MistralSystemMessage(content=part.content)
              elif isinstance(part, UserPromptPart):
-                 yield MistralUserMessage(content=part.content)
+                 yield cls._map_user_prompt(part)
              elif isinstance(part, ToolReturnPart):
                  yield MistralToolMessage(
                      tool_call_id=part.tool_call_id,
@@ -463,6 +465,29 @@ class MistralModel(Model):
          else:
              assert_never(message)
 
+     @staticmethod
+     def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+         content: str | list[MistralContentChunk]
+         if isinstance(part.content, str):
+             content = part.content
+         else:
+             content = []
+             for item in part.content:
+                 if isinstance(item, str):
+                     content.append(MistralTextChunk(text=item))
+                 elif isinstance(item, ImageUrl):
+                     content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url)))
+                 elif isinstance(item, BinaryContent):
+                     base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                     if item.is_image:
+                         image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                         content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                     else:
+                         raise RuntimeError('Only image binary content is supported for Mistral.')
+                 else:  # pragma: no cover
+                     raise RuntimeError(f'Unsupported content type: {type(item)}')
+         return MistralUserMessage(content=content)
+
 
  MistralToolCallId = Union[str, None]
 
@@ -1,11 +1,11 @@
  from __future__ import annotations as _annotations
 
+ import base64
  import os
- from collections.abc import AsyncIterable, AsyncIterator, Iterable
+ from collections.abc import AsyncIterable, AsyncIterator
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
- from itertools import chain
  from typing import Literal, Union, cast, overload
 
  from httpx import AsyncClient as AsyncHTTPClient
@@ -14,6 +14,9 @@ from typing_extensions import assert_never
  from .. import UnexpectedModelBehavior, _utils, usage
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
+     AudioUrl,
+     BinaryContent,
+     ImageUrl,
      ModelMessage,
      ModelRequest,
      ModelResponse,
@@ -39,7 +42,15 @@ from . import (
  try:
      from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
      from openai.types import ChatModel, chat
-     from openai.types.chat import ChatCompletionChunk
+     from openai.types.chat import (
+         ChatCompletionChunk,
+         ChatCompletionContentPartImageParam,
+         ChatCompletionContentPartInputAudioParam,
+         ChatCompletionContentPartParam,
+         ChatCompletionContentPartTextParam,
+     )
+     from openai.types.chat.chat_completion_content_part_image_param import ImageURL
+     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
  except ImportError as _import_error:
      raise ImportError(
          'Please install `openai` to use the OpenAI model, '
@@ -119,9 +130,9 @@ class OpenAIModel(Model):
          """
          self._model_name = model_name
          # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
-         # openai compatible models do not always need an API key.
+         # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
          if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
-             api_key = ''
+             api_key = 'api-key-not-set'
 
          if openai_client is not None:
              assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
@@ -135,9 +146,6 @@ class OpenAIModel(Model):
          self.system_prompt_role = system_prompt_role
          self._system = system
 
-     def name(self) -> str:
-         return f'openai:{self._model_name}'
-
      async def request(
          self,
          messages: list[ModelMessage],
@@ -211,7 +219,10 @@ class OpenAIModel(Model):
          else:
              tool_choice = 'auto'
 
-         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+         openai_messages: list[chat.ChatCompletionMessageParam] = []
+         for m in messages:
+             async for msg in self._map_message(m):
+                 openai_messages.append(msg)
 
          return await self.client.chat.completions.create(
              model=self._model_name,
@@ -264,10 +275,11 @@ class OpenAIModel(Model):
              tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
          return tools
 
-     def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+     async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
          """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
          if isinstance(message, ModelRequest):
-             yield from self._map_user_message(message)
+             async for item in self._map_user_message(message):
+                 yield item
          elif isinstance(message, ModelResponse):
              texts: list[str] = []
              tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -308,7 +320,7 @@ class OpenAIModel(Model):
              },
          }
 
-     def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+     async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
          for part in message.parts:
              if isinstance(part, SystemPromptPart):
                  if self.system_prompt_role == 'developer':
@@ -318,7 +330,7 @@ class OpenAIModel(Model):
                  else:
                      yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
              elif isinstance(part, UserPromptPart):
-                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                 yield await self._map_user_prompt(part)
              elif isinstance(part, ToolReturnPart):
                  yield chat.ChatCompletionToolMessageParam(
                      role='tool',
@@ -337,6 +349,40 @@ class OpenAIModel(Model):
              else:
                  assert_never(part)
 
+     @staticmethod
+     async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+         content: str | list[ChatCompletionContentPartParam]
+         if isinstance(part.content, str):
+             content = part.content
+         else:
+             content = []
+             for item in part.content:
+                 if isinstance(item, str):
+                     content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
+                 elif isinstance(item, ImageUrl):
+                     image_url = ImageURL(url=item.url)
+                     content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                 elif isinstance(item, BinaryContent):
+                     base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                     if item.is_image:
+                         image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                         content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                     elif item.is_audio:
+                         audio = InputAudio(data=base64_encoded, format=item.audio_format)
+                         content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
+                     else:  # pragma: no cover
+                         raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                 elif isinstance(item, AudioUrl):  # pragma: no cover
+                     client = cached_async_http_client()
+                     response = await client.get(item.url)
+                     response.raise_for_status()
+                     base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                     audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
+                     content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
+                 else:
+                     assert_never(item)
+         return chat.ChatCompletionUserMessageParam(role='user', content=content)
+
 
  @dataclass
  class OpenAIStreamedResponse(StreamedResponse):
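
With the OpenAI `_map_user_prompt` above, audio can now be sent alongside text: raw bytes become a base64 `input_audio` part, and an `AudioUrl` is downloaded first (that branch is marked `# pragma: no cover`). A minimal usage sketch under the same assumptions as the Groq example, with an illustrative audio-capable model name.

    from pydantic_ai import Agent
    from pydantic_ai.messages import BinaryContent

    agent = Agent('openai:gpt-4o-audio-preview')  # illustrative model name

    # Raw audio bytes are base64-encoded into an `input_audio` content part.
    with open('question.mp3', 'rb') as f:
        result = agent.run_sync(['Transcribe this audio.', BinaryContent(data=f.read(), media_type='audio/mpeg')])
    print(result.data)
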
@@ -130,15 +130,15 @@ class TestModel(Model):
 
      def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
          if self.custom_result_text is not None:
-             assert (
-                 model_request_parameters.allow_text_result
-             ), 'Plain response not allowed, but `custom_result_text` is set.'
+             assert model_request_parameters.allow_text_result, (
+                 'Plain response not allowed, but `custom_result_text` is set.'
+             )
              assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
              return _TextResult(self.custom_result_text)
          elif self.custom_result_args is not None:
-             assert (
-                 model_request_parameters.result_tools is not None
-             ), 'No result tools provided, but `custom_result_args` is set.'
+             assert model_request_parameters.result_tools is not None, (
+                 'No result tools provided, but `custom_result_args` is set.'
+             )
              result_tool = model_request_parameters.result_tools[0]
 
              if k := result_tool.outer_typed_dict_key:
@@ -0,0 +1,45 @@
+ from __future__ import annotations
+
+ from collections.abc import AsyncIterator
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass
+ from typing import Any
+
+ from ..messages import ModelMessage, ModelResponse
+ from ..settings import ModelSettings
+ from ..usage import Usage
+ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+
+ @dataclass(init=False)
+ class WrapperModel(Model):
+     """Model which wraps another model."""
+
+     wrapped: Model
+
+     def __init__(self, wrapped: Model | KnownModelName):
+         self.wrapped = infer_model(wrapped)
+
+     async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+         return await self.wrapped.request(*args, **kwargs)
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
+             yield response_stream
+
+     @property
+     def model_name(self) -> str:
+         return self.wrapped.model_name
+
+     @property
+     def system(self) -> str | None:
+         return self.wrapped.system
+
+     def __getattr__(self, item: str):
+         return getattr(self.wrapped, item)
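
The new `WrapperModel` delegates `request`, `request_stream`, `model_name`, `system`, and attribute access to the wrapped model, so small cross-cutting behaviours can be layered on by subclassing (this is exactly what `InstrumentedModel` does). A minimal sketch, not part of this diff; the wrapped model name is illustrative.

    from pydantic_ai import Agent
    from pydantic_ai.models.wrapper import WrapperModel


    class CountingModel(WrapperModel):
        """Toy wrapper that counts requests before delegating to the wrapped model."""

        def __init__(self, wrapped):
            super().__init__(wrapped)
            self.request_count = 0

        async def request(self, *args, **kwargs):
            self.request_count += 1
            return await super().request(*args, **kwargs)


    agent = Agent(CountingModel('openai:gpt-4o-mini'))  # model name illustrative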