pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/messages.py CHANGED
@@ -1,11 +1,14 @@
 from __future__ import annotations as _annotations
 
+import uuid
+from collections.abc import Sequence
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
 
 import pydantic
 import pydantic_core
+from typing_extensions import TypeAlias
 
 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -31,6 +34,93 @@ class SystemPromptPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
 
+@dataclass
+class AudioUrl:
+    """A URL to an audio file."""
+
+    url: str
+    """The URL of the audio file."""
+
+    kind: Literal['audio-url'] = 'audio-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> AudioMediaType:
+        """Return the media type of the audio file, based on the url."""
+        if self.url.endswith('.mp3'):
+            return 'audio/mpeg'
+        elif self.url.endswith('.wav'):
+            return 'audio/wav'
+        else:
+            raise ValueError(f'Unknown audio file extension: {self.url}')
+
+
+@dataclass
+class ImageUrl:
+    """A URL to an image."""
+
+    url: str
+    """The URL of the image."""
+
+    kind: Literal['image-url'] = 'image-url'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def media_type(self) -> ImageMediaType:
+        """Return the media type of the image, based on the url."""
+        if self.url.endswith(('.jpg', '.jpeg')):
+            return 'image/jpeg'
+        elif self.url.endswith('.png'):
+            return 'image/png'
+        elif self.url.endswith('.gif'):
+            return 'image/gif'
+        elif self.url.endswith('.webp'):
+            return 'image/webp'
+        else:
+            raise ValueError(f'Unknown image file extension: {self.url}')
+
+
+AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
+ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
+
+
+@dataclass
+class BinaryContent:
+    """Binary content, e.g. an audio or image file."""
+
+    data: bytes
+    """The binary data."""
+
+    media_type: AudioMediaType | ImageMediaType | str
+    """The media type of the binary data."""
+
+    kind: Literal['binary'] = 'binary'
+    """Type identifier, this is available on all parts as a discriminator."""
+
+    @property
+    def is_audio(self) -> bool:
+        """Return `True` if the media type is an audio type."""
+        return self.media_type.startswith('audio/')
+
+    @property
+    def is_image(self) -> bool:
+        """Return `True` if the media type is an image type."""
+        return self.media_type.startswith('image/')
+
+    @property
+    def audio_format(self) -> Literal['mp3', 'wav']:
+        """Return the audio format given the media type."""
+        if self.media_type == 'audio/mpeg':
+            return 'mp3'
+        elif self.media_type == 'audio/wav':
+            return 'wav'
+        else:
+            raise ValueError(f'Unknown audio media type: {self.media_type}')
+
+
+UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
+
+
 @dataclass
 class UserPromptPart:
     """A user prompt, generally written by the end user.
@@ -39,7 +129,7 @@ class UserPromptPart:
     [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
     """
 
-    content: str
+    content: str | Sequence[UserContent]
    """The content of the prompt."""
 
     timestamp: datetime = field(default_factory=_now_utc)
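
With `content` widened from `str` to `str | Sequence[UserContent]`, a user prompt can now interleave text with images, audio, or raw bytes. A minimal sketch of constructing such a part against this version (the URL and bytes are illustrative):

from pydantic_ai.messages import BinaryContent, ImageUrl, UserPromptPart

part = UserPromptPart(
    content=[
        'What is in this image?',
        ImageUrl(url='https://example.com/photo.png'),  # media_type inferred from the extension
        BinaryContent(data=b'<raw jpeg bytes>', media_type='image/jpeg'),
    ]
)
assert part.content[1].media_type == 'image/png'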
@@ -445,3 +535,33 @@ class PartDeltaEvent:
 
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+@dataclass
+class FunctionToolCallEvent:
+    """An event indicating the start to a call to a function tool."""
+
+    part: ToolCallPart
+    """The (function) tool call to make."""
+    call_id: str = field(init=False)
+    """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+    event_kind: Literal['function_tool_call'] = 'function_tool_call'
+    """Event type identifier, used as a discriminator."""
+
+    def __post_init__(self):
+        self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+
+
+@dataclass
+class FunctionToolResultEvent:
+    """An event indicating the result of a function tool call."""
+
+    result: ToolReturnPart | RetryPromptPart
+    """The result of the call to the function tool."""
+    call_id: str
+    """An ID used to match the result to its original call."""
+    event_kind: Literal['function_tool_result'] = 'function_tool_result'
+    """Event type identifier, used as a discriminator."""
+
+
+HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')]
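
Downstream consumers pair the two new events by `call_id`: the call event mints one in `__post_init__` (falling back to a fresh UUID when the model supplied no `tool_call_id`), and the result event carries it back. A hedged sketch of that bookkeeping (the event source is illustrative):

from pydantic_ai.messages import FunctionToolCallEvent, FunctionToolResultEvent

def pair_tool_events(events):
    """Match each FunctionToolResultEvent back to its originating call by call_id."""
    calls, pairs = {}, []
    for event in events:
        if isinstance(event, FunctionToolCallEvent):
            calls[event.call_id] = event.part
        elif isinstance(event, FunctionToolResultEvent):
            pairs.append((calls.pop(event.call_id, None), event.result))
    return pairs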

pydantic_ai/models/__init__.py CHANGED
@@ -234,6 +234,8 @@ class StreamedResponse(ABC):
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode
@@ -362,7 +364,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
        raise UserError(f'Unknown model: {model}')
 
 
-@cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
    """Cached HTTPX async client so multiple agents and calls can share the same client.
 
@@ -373,6 +374,16 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     The default timeouts match those of OpenAI,
     see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
     """
+    client = _cached_async_http_client(timeout=timeout, connect=connect)
+    if client.is_closed:
+        # This happens if the context manager is used, so we need to create a new client.
+        _cached_async_http_client.cache_clear()
+        client = _cached_async_http_client(timeout=timeout, connect=connect)
+    return client
+
+
+@cache
+def _cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
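
The new indirection exists because a plain `@cache` keeps handing back the same client even after a caller has closed it, e.g. by using it as an async context manager; the public function now detects that via `is_closed`, clears the cache, and rebuilds. A small self-contained demonstration of the failure mode the wrapper guards against (not part of the diff):

import asyncio
from functools import cache

import httpx

@cache
def naive_cached_client() -> httpx.AsyncClient:
    return httpx.AsyncClient()

async def main() -> None:
    async with naive_cached_client():
        pass  # leaving the context manager closes the client
    # The cache still returns the now-closed instance:
    assert naive_cached_client().is_closed

asyncio.run(main())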

pydantic_ai/models/anthropic.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterable, AsyncIterator
+import io
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -13,6 +14,8 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,6 +41,7 @@ from . import (
 try:
     from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
     from anthropic.types import (
+        ImageBlockParam,
         Message as AnthropicMessage,
         MessageParam,
         MetadataParam,
@@ -214,7 +218,7 @@ class AnthropicModel(Model):
         if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
             tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
 
-        system_prompt, anthropic_messages = self._map_message(messages)
+        system_prompt, anthropic_messages = await self._map_message(messages)
 
         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
@@ -266,69 +270,82 @@ class AnthropicModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
         return tools
 
-    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
         anthropic_messages: list[MessageParam] = []
         for m in messages:
             if isinstance(m, ModelRequest):
-                for part in m.parts:
-                    if isinstance(part, SystemPromptPart):
-                        system_prompt += part.content
-                    elif isinstance(part, UserPromptPart):
-                        anthropic_messages.append(MessageParam(role='user', content=part.content))
-                    elif isinstance(part, ToolReturnPart):
-                        anthropic_messages.append(
-                            MessageParam(
-                                role='user',
-                                content=[
-                                    ToolResultBlockParam(
-                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                        type='tool_result',
-                                        content=part.model_response_str(),
-                                        is_error=False,
-                                    )
-                                ],
-                            )
+                user_content_params: list[ToolResultBlockParam | TextBlockParam | ImageBlockParam] = []
+                for request_part in m.parts:
+                    if isinstance(request_part, SystemPromptPart):
+                        system_prompt += request_part.content
+                    elif isinstance(request_part, UserPromptPart):
+                        async for content in self._map_user_prompt(request_part):
+                            user_content_params.append(content)
+                    elif isinstance(request_part, ToolReturnPart):
+                        tool_result_block_param = ToolResultBlockParam(
+                            tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                            type='tool_result',
+                            content=request_part.model_response_str(),
+                            is_error=False,
                         )
-                    elif isinstance(part, RetryPromptPart):
-                        if part.tool_name is None:
-                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        user_content_params.append(tool_result_block_param)
+                    elif isinstance(request_part, RetryPromptPart):
+                        if request_part.tool_name is None:
+                            retry_param = TextBlockParam(type='text', text=request_part.model_response())
                         else:
-                            anthropic_messages.append(
-                                MessageParam(
-                                    role='user',
-                                    content=[
-                                        ToolResultBlockParam(
-                                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                                            type='tool_result',
-                                            content=part.model_response(),
-                                            is_error=True,
-                                        ),
-                                    ],
-                                )
+                            retry_param = ToolResultBlockParam(
+                                tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
+                                type='tool_result',
+                                content=request_part.model_response(),
+                                is_error=True,
                             )
+                        user_content_params.append(retry_param)
+                anthropic_messages.append(MessageParam(role='user', content=user_content_params))
             elif isinstance(m, ModelResponse):
-                content: list[TextBlockParam | ToolUseBlockParam] = []
-                for item in m.parts:
-                    if isinstance(item, TextPart):
-                        content.append(TextBlockParam(text=item.content, type='text'))
+                assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
+                for response_part in m.parts:
+                    if isinstance(response_part, TextPart):
+                        assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                     else:
-                        assert isinstance(item, ToolCallPart)
-                        content.append(self._map_tool_call(item))
-                anthropic_messages.append(MessageParam(role='assistant', content=content))
+                        tool_use_block_param = ToolUseBlockParam(
+                            id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
+                            type='tool_use',
+                            name=response_part.tool_name,
+                            input=response_part.args_as_dict(),
+                        )
+                        assistant_content_params.append(tool_use_block_param)
+                anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages
 
     @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-        return ToolUseBlockParam(
-            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-            type='tool_use',
-            name=t.tool_name,
-            input=t.args_as_dict(),
-        )
+    async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
+        if isinstance(part.content, str):
+            yield TextBlockParam(text=part.content, type='text')
+        else:
+            for item in part.content:
+                if isinstance(item, str):
+                    yield TextBlockParam(text=item, type='text')
+                elif isinstance(item, BinaryContent):
+                    if item.is_image:
+                        yield ImageBlockParam(
+                            source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
+                            type='image',
+                        )
+                    else:
+                        raise RuntimeError('Only images are supported for binary content')
+                elif isinstance(item, ImageUrl):
+                    response = await cached_async_http_client().get(item.url)
+                    response.raise_for_status()
+                    yield ImageBlockParam(
+                        source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
+                        type='image',
+                    )
+                else:
+                    raise RuntimeError(f'Unsupported content type: {type(item)}')
 
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
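
Two details of `_map_user_prompt` worth noting: Anthropic's `ImageBlockParam` documents `source['data']` as a base64-encoded string, while this code passes an `io.BytesIO` buffer (hence the `# type: ignore`), and downloaded `ImageUrl` content is always labelled `image/jpeg` regardless of its actual type. For comparison, a hedged sketch of the conventional base64 block shape (the helper name is hypothetical):

import base64

def image_block(data: bytes, media_type: str) -> dict:
    """Build an Anthropic-style base64 image block from raw bytes."""
    return {
        'type': 'image',
        'source': {
            'type': 'base64',
            'media_type': media_type,
            'data': base64.b64encode(data).decode('utf-8'),
        },
    }

block = image_block(b'\x89PNG...', 'image/png')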

pydantic_ai/models/cohere.py CHANGED
@@ -124,7 +124,7 @@ class CohereModel(Model):
            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
            self.client = cohere_client
        else:
-            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)
 
    async def request(
        self,
@@ -242,7 +242,10 @@ class CohereModel(Model):
        if isinstance(part, SystemPromptPart):
            yield SystemChatMessageV2(role='system', content=part.content)
        elif isinstance(part, UserPromptPart):
-            yield UserChatMessageV2(role='user', content=part.content)
+            if isinstance(part.content, str):
+                yield UserChatMessageV2(role='user', content=part.content)
+            else:
+                raise RuntimeError('Cohere does not yet support multi-modal inputs.')
        elif isinstance(part, ToolReturnPart):
            yield ToolChatMessageV2(
                role='tool',

pydantic_ai/models/function.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -14,6 +14,9 @@ from typing_extensions import TypeAlias, assert_never, overload
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -23,6 +26,7 @@ from ..messages import (
     TextPart,
     ToolCallPart,
     ToolReturnPart,
+    UserContent,
     UserPromptPart,
 )
 from ..settings import ModelSettings
@@ -109,9 +113,9 @@ class FunctionModel(Model):
             model_settings,
         )
 
-        assert (
-            self.stream_function is not None
-        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
+        assert self.stream_function is not None, (
+            'FunctionModel must receive a `stream_function` to support streamed requests'
+        )
 
         response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
@@ -262,7 +266,12 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     )
 
 
-def _estimate_string_tokens(content: str) -> int:
+def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
-    return len(re.split(r'[\s",.:]+', content.strip()))
+    if isinstance(content, str):
+        return len(re.split(r'[\s",.:]+', content.strip()))
+    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+    else:  # pragma: no cover
+        assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
+        return 0
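
The string heuristic counts the runs between whitespace and light punctuation. A quick worked check of the regex above (a standalone copy of the plain-string branch):

import re

def estimate_string_tokens(content: str) -> int:
    if not content:
        return 0
    return len(re.split(r'[\s",.:]+', content.strip()))

assert estimate_string_tokens('hello world') == 2
assert estimate_string_tokens('The answer is: 42') == 4
# Trailing punctuation leaves an empty trailing split, inflating the count by one:
assert estimate_string_tokens('Hi.') == 2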

pydantic_ai/models/gemini.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import base64
 import os
 import re
 from collections.abc import AsyncIterator, Sequence
@@ -16,6 +17,9 @@ from typing_extensions import NotRequired, TypedDict, assert_never
 
 from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -185,7 +189,7 @@ class GeminiModel(Model):
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
         tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -254,7 +258,7 @@ class GeminiModel(Model):
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
-                content,
+                _ensure_decodeable(content),
                experimental_allow_partial='trailing-strings',
            )
            if responses:
@@ -269,7 +273,7 @@ class GeminiModel(Model):
         return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
-    def _message_to_gemini_content(
+    async def _message_to_gemini_content(
         cls, messages: list[ModelMessage]
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
@@ -282,7 +286,7 @@ class GeminiModel(Model):
                    if isinstance(part, SystemPromptPart):
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
-                        message_parts.append(_GeminiTextPart(text=part.content))
+                        message_parts.extend(await cls._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
@@ -303,6 +307,34 @@ class GeminiModel(Model):
 
         return sys_prompt_parts, contents
 
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
+        if isinstance(part.content, str):
+            return [{'text': part.content}]
+        else:
+            content: list[_GeminiPartUnion] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append({'text': item})
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type))
+                elif isinstance(item, (AudioUrl, ImageUrl)):
+                    try:
+                        content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type))
+                    except ValueError:
+                        # Download the file if can't find the mime type.
+                        client = cached_async_http_client()
+                        response = await client.get(item.url, follow_redirects=True)
+                        response.raise_for_status()
+                        base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                        content.append(
+                            _GeminiInlineDataPart(data=base64_encoded, mime_type=response.headers['Content-Type'])
+                        )
+                else:
+                    assert_never(item)
+            return content
+
 
 class AuthProtocol(Protocol):
     """Abstract definition for Gemini authentication."""
@@ -370,7 +402,7 @@ class GeminiStreamedResponse(StreamedResponse):
            self._content.extend(chunk)
 
            gemini_responses = _gemini_streamed_response_ta.validate_json(
-                self._content,
+                _ensure_decodeable(self._content),
                experimental_allow_partial='trailing-strings',
            )
 
@@ -494,6 +526,20 @@ class _GeminiTextPart(TypedDict):
     text: str
 
 
+class _GeminiInlineDataPart(TypedDict):
+    """See <https://ai.google.dev/api/caching#Blob>."""
+
+    data: str
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
+class _GeminiFileDataData(TypedDict):
+    """See <https://ai.google.dev/api/caching#FileData>."""
+
+    file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
+    mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -549,6 +595,10 @@ def _part_discriminator(v: Any) -> str:
     if isinstance(v, dict):
         if 'text' in v:
             return 'text'
+        elif 'inlineData' in v:
+            return 'inline_data'
+        elif 'fileData' in v:
+            return 'file_data'
         elif 'functionCall' in v or 'function_call' in v:
             return 'function_call'
         elif 'functionResponse' in v or 'function_response' in v:
@@ -564,6 +614,8 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiTextPart, pydantic.Tag('text')],
         Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
+        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
+        Annotated[_GeminiFileDataData, pydantic.Tag('file_data')],
     ],
     pydantic.Discriminator(_part_discriminator),
 ]
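
The two new members slot into the existing callable-discriminator pattern: `_part_discriminator` inspects the raw dict and returns a tag, and pydantic validates against the union member carrying that `Tag`. A self-contained sketch of the mechanism with illustrative (non-private) names:

from typing import Annotated, Any, Union

import pydantic
from typing_extensions import TypedDict

class TextPart(TypedDict):
    text: str

class InlineDataPart(TypedDict):
    data: str
    mime_type: str

def part_discriminator(v: Any) -> str:
    if isinstance(v, dict) and 'text' in v:
        return 'text'
    return 'inline_data'

PartUnion = Annotated[
    Union[
        Annotated[TextPart, pydantic.Tag('text')],
        Annotated[InlineDataPart, pydantic.Tag('inline_data')],
    ],
    pydantic.Discriminator(part_discriminator),
]

adapter = pydantic.TypeAdapter(PartUnion)
assert adapter.validate_python({'text': 'hi'}) == {'text': 'hi'}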
@@ -774,3 +826,19 @@ class _GeminiJsonSchema:
 
         if items_schema := schema.get('items'):  # pragma: no branch
             self._simplify(items_schema, refs_stack)
+
+
+def _ensure_decodeable(content: bytearray) -> bytearray:
+    """Trim any invalid unicode point bytes off the end of a bytearray.
+
+    This is necessary before attempting to parse streaming JSON bytes.
+
+    This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
+    """
+    while True:
+        try:
+            content.decode()
+        except UnicodeDecodeError:
+            content = content[:-1]  # this will definitely succeed before we run out of bytes
+        else:
+            return content
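
Streaming can cut a multi-byte UTF-8 code point at a chunk boundary, which is exactly what `_ensure_decodeable` trims away. A quick demonstration (the euro sign encodes as three bytes):

def ensure_decodeable(content: bytearray) -> bytearray:
    # Standalone copy of _ensure_decodeable above.
    while True:
        try:
            content.decode()
        except UnicodeDecodeError:
            content = content[:-1]
        else:
            return content

chunk = bytearray('{"text": "€'.encode('utf-8'))[:-1]  # chunk ends mid code point
assert ensure_decodeable(chunk) == bytearray(b'{"text": "')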