pydantic-ai-slim 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of pydantic-ai-slim has been flagged as potentially problematic.

@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import AsyncExitStack, asynccontextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable
+
+from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
+
+if TYPE_CHECKING:
+    from ..messages import ModelMessage, ModelResponse
+    from ..settings import ModelSettings
+    from ..usage import Usage
+
+
+@dataclass(init=False)
+class FallbackModel(Model):
+    """A model that uses one or more fallback models upon failure.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    models: list[Model]
+
+    _model_name: str = field(repr=False)
+    _fallback_on: Callable[[Exception], bool]
+
+    def __init__(
+        self,
+        default_model: Model | KnownModelName,
+        *fallback_models: Model | KnownModelName,
+        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
+    ):
+        """Initialize a fallback model instance.
+
+        Args:
+            default_model: The name or instance of the default model to use.
+            fallback_models: The names or instances of the fallback models to use upon failure.
+            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
+        """
+        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
+        self._model_name = f'FallBackModel[{", ".join(model.model_name for model in self.models)}]'
+
+        if isinstance(fallback_on, tuple):
+            self._fallback_on = _default_fallback_condition_factory(fallback_on)
+        else:
+            self._fallback_on = fallback_on
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        """Try each model in sequence until one succeeds.
+
+        In case of failure, raise a FallbackExceptionGroup with all exceptions.
+        """
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            try:
+                return await model.request(messages, model_settings, model_request_parameters)
+            except Exception as exc:
+                if self._fallback_on(exc):
+                    exceptions.append(exc)
+                    continue
+                raise exc
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        """Try each model in sequence until one succeeds."""
+        exceptions: list[Exception] = []
+
+        for model in self.models:
+            async with AsyncExitStack() as stack:
+                try:
+                    response = await stack.enter_async_context(
+                        model.request_stream(messages, model_settings, model_request_parameters)
+                    )
+                except Exception as exc:
+                    if self._fallback_on(exc):
+                        exceptions.append(exc)
+                        continue
+                    raise exc
+                yield response
+                return
+
+        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider, n/a for fallback models."""
+        return None
+
+
+def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
+    """Create a default fallback condition for the given exceptions."""
+
+    def fallback_condition(exception: Exception) -> bool:
+        return isinstance(exception, exceptions)
+
+    return fallback_condition
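
The new class above routes each request through a list of models in order, falling back when the configured condition matches the raised exception and re-raising otherwise. A minimal usage sketch, assuming the new module is importable as pydantic_ai.models.fallback (the file path is not shown in this diff) and using placeholder model names:

    from pydantic_ai import Agent
    from pydantic_ai.models.fallback import FallbackModel  # assumed module path

    # ModelHTTPError is the default fallback trigger, so an HTTP error from the
    # first model moves on to the next; if every model fails, the
    # FallbackExceptionGroup collecting all errors is raised instead.
    model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')  # placeholder names
    agent = Agent(model)

    result = agent.run_sync('What is the capital of France?')
    print(result.data)

A custom trigger can also be supplied, either as a tuple of exception types or as any callable taking the exception and returning a bool, via the fallback_on keyword shown above.
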
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,6 +14,9 @@ from typing_extensions import TypeAlias, assert_never, overload
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -23,6 +26,7 @@ from ..messages import (
     TextPart,
     ToolCallPart,
     ToolReturnPart,
+    UserContent,
     UserPromptPart,
 )
 from ..settings import ModelSettings
@@ -44,15 +48,23 @@ class FunctionModel(Model):
     _system: str | None = field(default=None, repr=False)
 
     @overload
-    def __init__(self, function: FunctionDef) -> None: ...
+    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...
+    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
 
     @overload
-    def __init__(self, function: FunctionDef, *, stream_function: StreamFunctionDef) -> None: ...
+    def __init__(
+        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
+    ) -> None: ...
 
-    def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None):
+    def __init__(
+        self,
+        function: FunctionDef | None = None,
+        *,
+        stream_function: StreamFunctionDef | None = None,
+        model_name: str | None = None,
+    ):
         """Initialize a `FunctionModel`.
 
         Either `function` or `stream_function` must be provided, providing both is allowed.
@@ -60,6 +72,7 @@ class FunctionModel(Model):
         Args:
            function: The function to call for non-streamed requests.
            stream_function: The function to call for streamed requests.
+           model_name: The name of the model. If not provided, a name is generated from the function names.
         """
         if function is None and stream_function is None:
             raise TypeError('Either `function` or `stream_function` must be provided')
@@ -68,7 +81,7 @@ class FunctionModel(Model):
 
         function_name = self.function.__name__ if self.function is not None else ''
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
-        self._model_name = f'function:{function_name}:{stream_function_name}'
+        self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
 
     async def request(
         self,
@@ -91,7 +104,7 @@ class FunctionModel(Model):
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
-        response.model_name = f'function:{self.function.__name__}'
+        response.model_name = self._model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
@@ -119,7 +132,7 @@ class FunctionModel(Model):
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
+        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
 
     @property
     def model_name(self) -> str:
@@ -262,7 +275,12 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     )
 
 
-def _estimate_string_tokens(content: str) -> int:
+def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
-    return len(re.split(r'[\s",.:]+', content.strip()))
+    if isinstance(content, str):
+        return len(re.split(r'[\s",.:]+', content.strip()))
+    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+    else:  # pragma: no cover
+        assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
+        return 0
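
The model_name override above is mainly a convenience for tests that assert on response.model_name. A short sketch, assuming FunctionModel and AgentInfo are importable from pydantic_ai.models.function (the file path is not shown in this diff):

    from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
    from pydantic_ai.models.function import AgentInfo, FunctionModel  # assumed module path


    def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Trivial model function that always answers with fixed text.
        return ModelResponse(parts=[TextPart('hello')])


    # Without model_name this model would be named 'function:echo:'; the new
    # keyword pins an explicit, stable name that is also set on every response.
    model = FunctionModel(echo, model_name='my-test-model')
    print(model.model_name)  # my-test-model
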
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 import re
 from collections.abc import AsyncIterator, Sequence
@@ -14,8 +15,11 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, _utils, exceptions, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -108,7 +112,7 @@ class GeminiModel(Model):
             if env_api_key := os.getenv('GEMINI_API_KEY'):
                 api_key = env_api_key
             else:
-                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
+                raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
         self.http_client = http_client or cached_async_http_client()
         self._auth = ApiKeyAuth(api_key)
         self._url = url_template.format(model=model_name)
@@ -185,7 +189,7 @@ class GeminiModel(Model):
    ) -> AsyncIterator[HTTPResponse]:
        tools = self._get_tools(model_request_parameters)
        tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
 
        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
@@ -229,9 +233,11 @@ class GeminiModel(Model):
             headers=headers,
             timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
-            if r.status_code != 200:
+            if (status_code := r.status_code) != 200:
                 await r.aread()
-                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
+                if status_code >= 400:
+                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
+                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
             yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
@@ -269,7 +275,7 @@ class GeminiModel(Model):
         return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
-    def _message_to_gemini_content(
+    async def _message_to_gemini_content(
         cls, messages: list[ModelMessage]
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
@@ -282,7 +288,7 @@ class GeminiModel(Model):
                     if isinstance(part, SystemPromptPart):
                         sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, UserPromptPart):
-                        message_parts.append(_GeminiTextPart(text=part.content))
+                        message_parts.extend(await cls._map_user_prompt(part))
                     elif isinstance(part, ToolReturnPart):
                         message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                     elif isinstance(part, RetryPromptPart):
@@ -303,6 +309,40 @@ class GeminiModel(Model):
 
         return sys_prompt_parts, contents
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
+        if isinstance(part.content, str):
+            return [{'text': part.content}]
+        else:
+            content: list[_GeminiPartUnion] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append({'text': item})
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    content.append(
+                        _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
+                    )
+                elif isinstance(item, (AudioUrl, ImageUrl)):
+                    try:
+                        content.append(
+                            _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                        )
+                    except ValueError:
+                        # Download the file if can't find the mime type.
+                        client = cached_async_http_client()
+                        response = await client.get(item.url, follow_redirects=True)
+                        response.raise_for_status()
+                        base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                        content.append(
+                            _GeminiInlineDataPart(
+                                inline_data={'data': base64_encoded, 'mime_type': response.headers['Content-Type']}
+                            )
+                        )
+                else:
+                    assert_never(item)
+            return content
+
 
 class AuthProtocol(Protocol):
     """Abstract definition for Gemini authentication."""
@@ -494,6 +534,28 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineData(TypedDict):
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiInlineDataPart(TypedDict):
+    """See <https://ai.google.dev/api/caching#Blob>."""
+
+    inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]
+
+
+class _GeminiFileData(TypedDict):
+    """See <https://ai.google.dev/api/caching#FileData>."""
+
+    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiFileDataPart(TypedDict):
+    file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -517,7 +579,7 @@ def _process_response_from_parts(
                 )
             )
         elif 'function_response' in part:
-            raise exceptions.UnexpectedModelBehavior(
+            raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
@@ -549,6 +611,10 @@ def _part_discriminator(v: Any) -> str:
     if isinstance(v, dict):
         if 'text' in v:
             return 'text'
+        elif 'inlineData' in v:
+            return 'inline_data'
+        elif 'fileData' in v:
+            return 'file_data'
         elif 'functionCall' in v or 'function_call' in v:
             return 'function_call'
         elif 'functionResponse' in v or 'function_response' in v:
@@ -564,6 +630,8 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiTextPart, pydantic.Tag('text')],
         Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
+        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
+        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
     ],
     pydantic.Discriminator(_part_discriminator),
 ]
@@ -726,7 +794,7 @@ class _GeminiJsonSchema:
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
             if key in refs_stack:
-                raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
+                raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
             refs_stack += (key,)
             schema_def = self.defs[key]
             self._simplify(schema_def, refs_stack)
@@ -760,7 +828,7 @@ class _GeminiJsonSchema:
     def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         ad_props = schema.pop('additionalProperties', None)
         if ad_props:
-            raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')
+            raise UserError('Additional properties in JSON Schema are not supported by Gemini')
 
         if properties := schema.get('properties'):  # pragma: no branch
             for value in properties.values():
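
The Gemini changes above map image and audio user content onto the new inlineData / fileData request parts. A sketch of how such content could reach this code path through an agent, assuming run_sync accepts a sequence of user content (as the new UserContent import suggests); the model name and URL are placeholders:

    from pydantic_ai import Agent
    from pydantic_ai.messages import ImageUrl
    from pydantic_ai.models.gemini import GeminiModel  # assumed module path

    # GEMINI_API_KEY is read from the environment, per the __init__ shown above.
    agent = Agent(GeminiModel('gemini-1.5-flash'))  # placeholder model name

    result = agent.run_sync(
        [
            'What is shown in this image?',
            ImageUrl(url='https://example.com/picture.jpg'),  # placeholder URL
        ]
    )
    print(result.data)

An ImageUrl or AudioUrl with a recognizable media type is sent as a fileData reference; otherwise the file is downloaded and inlined as base64, as the except ValueError branch shows.
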
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -10,9 +11,11 @@ from typing import Literal, Union, cast, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -36,9 +39,9 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
-    from groq.types.chat import ChatCompletion, ChatCompletionChunk
+    from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -163,7 +166,7 @@ class GroqModel(Model):
         stream: Literal[True],
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> AsyncStream[ChatCompletionChunk]:
+    ) -> AsyncStream[chat.ChatCompletionChunk]:
         pass
 
     @overload
@@ -182,7 +185,7 @@ class GroqModel(Model):
         stream: bool,
         model_settings: GroqModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
         # standalone function to make it easier to override
         if not tools:
@@ -194,23 +197,28 @@ class GroqModel(Model):
 
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
 
-        return await self.client.chat.completions.create(
-            model=str(self._model_name),
-            messages=groq_messages,
-            n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
-            tools=tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
-            stream=stream,
-            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
-            temperature=model_settings.get('temperature', NOT_GIVEN),
-            top_p=model_settings.get('top_p', NOT_GIVEN),
-            timeout=model_settings.get('timeout', NOT_GIVEN),
-            seed=model_settings.get('seed', NOT_GIVEN),
-            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
-        )
+        try:
+            return await self.client.chat.completions.create(
+                model=str(self._model_name),
+                messages=groq_messages,
+                n=1,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                stream=stream,
+                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                seed=model_settings.get('seed', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -224,7 +232,7 @@ class GroqModel(Model):
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
-    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -293,7 +301,7 @@ class GroqModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                yield cls._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
@@ -310,13 +318,37 @@ class GroqModel(Model):
                     content=part.model_response(),
                 )
 
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        content: str | list[chat.ChatCompletionContentPartParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
+                elif isinstance(item, ImageUrl):
+                    image_url = ImageURL(url=item.url)
+                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only images are supported for binary content in Groq.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return chat.ChatCompletionUserMessageParam(role='user', content=content)
+
 
 @dataclass
 class GroqStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Groq models."""
 
     _model_name: GroqModelName
-    _response: AsyncIterable[ChatCompletionChunk]
+    _response: AsyncIterable[chat.ChatCompletionChunk]
     _timestamp: datetime
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -355,9 +387,9 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
     response_usage = None
-    if isinstance(completion, ChatCompletion):
+    if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
     elif completion.x_groq is not None:
         response_usage = completion.x_groq.usage
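
As with Gemini above, Groq's 4xx/5xx responses are now normalised to ModelHTTPError, which is also the default trigger for FallbackModel. A sketch of handling it directly, with a placeholder model name:

    from pydantic_ai import Agent
    from pydantic_ai.exceptions import ModelHTTPError

    agent = Agent('groq:llama-3.3-70b-versatile')  # placeholder model name

    try:
        result = agent.run_sync('Say hello')
        print(result.data)
    except ModelHTTPError as exc:
        # status_code, model_name and body mirror the constructor arguments used above.
        print(f'{exc.model_name} returned HTTP {exc.status_code}: {exc.body}')
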
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
@@ -12,9 +13,11 @@ import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -45,6 +48,8 @@ try:
         Content as MistralContent,
         ContentChunk as MistralContentChunk,
         FunctionCall as MistralFunctionCall,
+        ImageURL as MistralImageURL,
+        ImageURLChunk as MistralImageURLChunk,
         Mistral,
         OptionalNullable as MistralOptionalNullable,
         TextChunk as MistralTextChunk,
@@ -54,6 +59,7 @@ try:
         ChatCompletionResponse as MistralChatCompletionResponse,
         CompletionEvent as MistralCompletionEvent,
         Messages as MistralMessages,
+        SDKError,
         Tool as MistralTool,
         ToolCall as MistralToolCall,
     )
@@ -179,19 +185,25 @@ class MistralModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        response = await self.client.chat.complete_async(
-            model=str(self._model_name),
-            messages=list(chain(*(self._map_message(m) for m in messages))),
-            n=1,
-            tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
-            tool_choice=self._get_tool_choice(model_request_parameters),
-            stream=False,
-            max_tokens=model_settings.get('max_tokens', UNSET),
-            temperature=model_settings.get('temperature', UNSET),
-            top_p=model_settings.get('top_p', 1),
-            timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
-            random_seed=model_settings.get('seed', UNSET),
-        )
+        try:
+            response = await self.client.chat.complete_async(
+                model=str(self._model_name),
+                messages=list(chain(*(self._map_message(m) for m in messages))),
+                n=1,
+                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tool_choice=self._get_tool_choice(model_request_parameters),
+                stream=False,
+                max_tokens=model_settings.get('max_tokens', UNSET),
+                temperature=model_settings.get('temperature', UNSET),
+                top_p=model_settings.get('top_p', 1),
+                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                random_seed=model_settings.get('seed', UNSET),
+            )
+        except SDKError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
+
         assert response, 'A unexpected empty response from Mistral.'
         return response
 
@@ -423,7 +435,7 @@ class MistralModel(Model):
             if isinstance(part, SystemPromptPart):
                 yield MistralSystemMessage(content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield MistralUserMessage(content=part.content)
+                yield cls._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield MistralToolMessage(
                     tool_call_id=part.tool_call_id,
@@ -460,6 +472,29 @@ class MistralModel(Model):
         else:
             assert_never(message)
 
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+        content: str | list[MistralContentChunk]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(MistralTextChunk(text=item))
+                elif isinstance(item, ImageUrl):
+                    content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url)))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
+                        content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                    else:
+                        raise RuntimeError('Only image binary content is supported for Mistral.')
+                else:  # pragma: no cover
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
+        return MistralUserMessage(content=content)
+
 
 MistralToolCallId = Union[str, None]
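
Like the Groq mapping, the Mistral mapping only accepts image data for BinaryContent and rejects anything else. A sketch of sending raw image bytes, assuming BinaryContent takes data and media_type arguments (only the attribute accesses are visible in this diff) and using a placeholder model name:

    from pathlib import Path

    from pydantic_ai import Agent
    from pydantic_ai.messages import BinaryContent

    image_bytes = Path('photo.png').read_bytes()

    agent = Agent('mistral:pixtral-12b-latest')  # placeholder model name
    result = agent.run_sync(
        [
            'Describe this photo.',
            # field names assumed from the item.data / item.media_type accesses above
            BinaryContent(data=image_bytes, media_type='image/png'),
        ]
    )
    print(result.data)
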