pydantic-ai-slim 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/mcp.py CHANGED
@@ -2,11 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  import base64
4
4
  import functools
5
+ import warnings
5
6
  from abc import ABC, abstractmethod
6
7
  from asyncio import Lock
7
8
  from collections.abc import AsyncIterator, Awaitable, Sequence
8
9
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
9
10
  from dataclasses import dataclass, field, replace
11
+ from datetime import timedelta
10
12
  from pathlib import Path
11
13
  from typing import Any, Callable
12
14
 
@@ -37,7 +39,7 @@ except ImportError as _import_error:
37
39
  ) from _import_error
38
40
 
39
41
  # after mcp imports so any import error maps to this file, not _mcp.py
40
- from . import _mcp, exceptions, messages, models
42
+ from . import _mcp, _utils, exceptions, messages, models
41
43
 
42
44
  __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
43
45
 
@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
59
61
  log_level: mcp_types.LoggingLevel | None = None
60
62
  log_handler: LoggingFnT | None = None
61
63
  timeout: float = 5
64
+ read_timeout: float = 5 * 60
62
65
  process_tool_call: ProcessToolCallback | None = None
63
66
  allow_sampling: bool = True
64
67
  max_retries: int = 1
@@ -148,7 +151,7 @@ class MCPServer(AbstractToolset[Any], ABC):
148
151
  except McpError as e:
149
152
  raise exceptions.ModelRetry(e.error.message)
150
153
 
151
- content = [self._map_tool_result_part(part) for part in result.content]
154
+ content = [await self._map_tool_result_part(part) for part in result.content]
152
155
 
153
156
  if result.isError:
154
157
  text = '\n'.join(str(part) for part in content)
@@ -208,6 +211,7 @@ class MCPServer(AbstractToolset[Any], ABC):
208
211
  write_stream=self._write_stream,
209
212
  sampling_callback=self._sampling_callback if self.allow_sampling else None,
210
213
  logging_callback=self.log_handler,
214
+ read_timeout_seconds=timedelta(seconds=self.read_timeout),
211
215
  )
212
216
  self._client = await self._exit_stack.enter_async_context(client)
213
217
 
@@ -258,8 +262,8 @@ class MCPServer(AbstractToolset[Any], ABC):
258
262
  model=self.sampling_model.model_name,
259
263
  )
260
264
 
261
- def _map_tool_result_part(
262
- self, part: mcp_types.Content
265
+ async def _map_tool_result_part(
266
+ self, part: mcp_types.ContentBlock
263
267
  ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
264
268
  # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
265
269
 
@@ -281,18 +285,29 @@ class MCPServer(AbstractToolset[Any], ABC):
281
285
  ) # pragma: no cover
282
286
  elif isinstance(part, mcp_types.EmbeddedResource):
283
287
  resource = part.resource
284
- if isinstance(resource, mcp_types.TextResourceContents):
285
- return resource.text
286
- elif isinstance(resource, mcp_types.BlobResourceContents):
287
- return messages.BinaryContent(
288
- data=base64.b64decode(resource.blob),
289
- media_type=resource.mimeType or 'application/octet-stream',
290
- )
291
- else:
292
- assert_never(resource)
288
+ return self._get_content(resource)
289
+ elif isinstance(part, mcp_types.ResourceLink):
290
+ resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
291
+ return (
292
+ self._get_content(resource_result.contents[0])
293
+ if len(resource_result.contents) == 1
294
+ else [self._get_content(resource) for resource in resource_result.contents]
295
+ )
293
296
  else:
294
297
  assert_never(part)
295
298
 
299
+ def _get_content(
300
+ self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
301
+ ) -> str | messages.BinaryContent:
302
+ if isinstance(resource, mcp_types.TextResourceContents):
303
+ return resource.text
304
+ elif isinstance(resource, mcp_types.BlobResourceContents):
305
+ return messages.BinaryContent(
306
+ data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
307
+ )
308
+ else:
309
+ assert_never(resource)
310
+
296
311
 
297
312
  @dataclass
298
313
  class MCPServerStdio(MCPServer):
@@ -401,7 +416,7 @@ class MCPServerStdio(MCPServer):
401
416
  return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
402
417
 
403
418
 
404
- @dataclass
419
+ @dataclass(init=False)
405
420
  class _MCPServerHTTP(MCPServer):
406
421
  url: str
407
422
  """The URL of the endpoint on the MCP server."""
@@ -438,10 +453,10 @@ class _MCPServerHTTP(MCPServer):
438
453
  ```
439
454
  """
440
455
 
441
- sse_read_timeout: float = 5 * 60
442
- """Maximum time in seconds to wait for new SSE messages before timing out.
456
+ read_timeout: float = 5 * 60
457
+ """Maximum time in seconds to wait for new messages before timing out.
443
458
 
444
- This timeout applies to the long-lived SSE connection after it's established.
459
+ This timeout applies to the long-lived connection after it's established.
445
460
  If no new messages are received within this time, the connection will be considered stale
446
461
  and may be closed. Defaults to 5 minutes (300 seconds).
447
462
  """
@@ -485,6 +500,51 @@ class _MCPServerHTTP(MCPServer):
485
500
  sampling_model: models.Model | None = None
486
501
  """The model to use for sampling."""
487
502
 
503
+ def __init__(
504
+ self,
505
+ *,
506
+ url: str,
507
+ headers: dict[str, str] | None = None,
508
+ http_client: httpx.AsyncClient | None = None,
509
+ read_timeout: float | None = None,
510
+ tool_prefix: str | None = None,
511
+ log_level: mcp_types.LoggingLevel | None = None,
512
+ log_handler: LoggingFnT | None = None,
513
+ timeout: float = 5,
514
+ process_tool_call: ProcessToolCallback | None = None,
515
+ allow_sampling: bool = True,
516
+ max_retries: int = 1,
517
+ sampling_model: models.Model | None = None,
518
+ **kwargs: Any,
519
+ ):
520
+ # Handle deprecated sse_read_timeout parameter
521
+ if 'sse_read_timeout' in kwargs:
522
+ if read_timeout is not None:
523
+ raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
524
+
525
+ warnings.warn(
526
+ "'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
527
+ )
528
+ read_timeout = kwargs.pop('sse_read_timeout')
529
+
530
+ _utils.validate_empty_kwargs(kwargs)
531
+
532
+ if read_timeout is None:
533
+ read_timeout = 5 * 60
534
+
535
+ self.url = url
536
+ self.headers = headers
537
+ self.http_client = http_client
538
+ self.tool_prefix = tool_prefix
539
+ self.log_level = log_level
540
+ self.log_handler = log_handler
541
+ self.timeout = timeout
542
+ self.process_tool_call = process_tool_call
543
+ self.allow_sampling = allow_sampling
544
+ self.max_retries = max_retries
545
+ self.sampling_model = sampling_model
546
+ self.read_timeout = read_timeout
547
+
488
548
  @property
489
549
  @abstractmethod
490
550
  def _transport_client(
@@ -522,7 +582,7 @@ class _MCPServerHTTP(MCPServer):
522
582
  self._transport_client,
523
583
  url=self.url,
524
584
  timeout=self.timeout,
525
- sse_read_timeout=self.sse_read_timeout,
585
+ sse_read_timeout=self.read_timeout,
526
586
  )
527
587
 
528
588
  if self.http_client is not None:
@@ -549,7 +609,7 @@ class _MCPServerHTTP(MCPServer):
549
609
  return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
550
610
 
551
611
 
552
- @dataclass
612
+ @dataclass(init=False)
553
613
  class MCPServerSSE(_MCPServerHTTP):
554
614
  """An MCP server that connects over streamable HTTP connections.
555
615
 
pydantic_ai/messages.py CHANGED
@@ -85,7 +85,7 @@ class SystemPromptPart:
85
85
  __repr__ = _utils.dataclasses_no_defaults_repr
86
86
 
87
87
 
88
- @dataclass(repr=False)
88
+ @dataclass(init=False, repr=False)
89
89
  class FileUrl(ABC):
90
90
  """Abstract base class for any URL-based file."""
91
91
 
@@ -106,11 +106,29 @@ class FileUrl(ABC):
106
106
  - `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
107
107
  """
108
108
 
109
- @property
109
+ _media_type: str | None = field(init=False, repr=False)
110
+
111
+ def __init__(
112
+ self,
113
+ url: str,
114
+ force_download: bool = False,
115
+ vendor_metadata: dict[str, Any] | None = None,
116
+ media_type: str | None = None,
117
+ ) -> None:
118
+ self.url = url
119
+ self.vendor_metadata = vendor_metadata
120
+ self.force_download = force_download
121
+ self._media_type = media_type
122
+
110
123
  @abstractmethod
111
- def media_type(self) -> str:
124
+ def _infer_media_type(self) -> str:
112
125
  """Return the media type of the file, based on the url."""
113
126
 
127
+ @property
128
+ def media_type(self) -> str:
129
+ """Return the media type of the file, based on the url or the provided `_media_type`."""
130
+ return self._media_type or self._infer_media_type()
131
+
114
132
  @property
115
133
  @abstractmethod
116
134
  def format(self) -> str:
@@ -119,7 +137,7 @@ class FileUrl(ABC):
119
137
  __repr__ = _utils.dataclasses_no_defaults_repr
120
138
 
121
139
 
122
- @dataclass(repr=False)
140
+ @dataclass(init=False, repr=False)
123
141
  class VideoUrl(FileUrl):
124
142
  """A URL to a video."""
125
143
 
@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
129
147
  kind: Literal['video-url'] = 'video-url'
130
148
  """Type identifier, this is available on all parts as a discriminator."""
131
149
 
132
- @property
133
- def media_type(self) -> VideoMediaType:
150
+ def __init__(
151
+ self,
152
+ url: str,
153
+ force_download: bool = False,
154
+ vendor_metadata: dict[str, Any] | None = None,
155
+ media_type: str | None = None,
156
+ kind: Literal['video-url'] = 'video-url',
157
+ ) -> None:
158
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
159
+ self.kind = kind
160
+
161
+ def _infer_media_type(self) -> VideoMediaType:
134
162
  """Return the media type of the video, based on the url."""
135
163
  if self.url.endswith('.mkv'):
136
164
  return 'video/x-matroska'
@@ -170,7 +198,7 @@ class VideoUrl(FileUrl):
170
198
  return _video_format_lookup[self.media_type]
171
199
 
172
200
 
173
- @dataclass(repr=False)
201
+ @dataclass(init=False, repr=False)
174
202
  class AudioUrl(FileUrl):
175
203
  """A URL to an audio file."""
176
204
 
@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
180
208
  kind: Literal['audio-url'] = 'audio-url'
181
209
  """Type identifier, this is available on all parts as a discriminator."""
182
210
 
183
- @property
184
- def media_type(self) -> AudioMediaType:
211
+ def __init__(
212
+ self,
213
+ url: str,
214
+ force_download: bool = False,
215
+ vendor_metadata: dict[str, Any] | None = None,
216
+ media_type: str | None = None,
217
+ kind: Literal['audio-url'] = 'audio-url',
218
+ ) -> None:
219
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
220
+ self.kind = kind
221
+
222
+ def _infer_media_type(self) -> AudioMediaType:
185
223
  """Return the media type of the audio file, based on the url.
186
224
 
187
225
  References:
@@ -208,7 +246,7 @@ class AudioUrl(FileUrl):
208
246
  return _audio_format_lookup[self.media_type]
209
247
 
210
248
 
211
- @dataclass(repr=False)
249
+ @dataclass(init=False, repr=False)
212
250
  class ImageUrl(FileUrl):
213
251
  """A URL to an image."""
214
252
 
@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
218
256
  kind: Literal['image-url'] = 'image-url'
219
257
  """Type identifier, this is available on all parts as a discriminator."""
220
258
 
221
- @property
222
- def media_type(self) -> ImageMediaType:
259
+ def __init__(
260
+ self,
261
+ url: str,
262
+ force_download: bool = False,
263
+ vendor_metadata: dict[str, Any] | None = None,
264
+ media_type: str | None = None,
265
+ kind: Literal['image-url'] = 'image-url',
266
+ ) -> None:
267
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
268
+ self.kind = kind
269
+
270
+ def _infer_media_type(self) -> ImageMediaType:
223
271
  """Return the media type of the image, based on the url."""
224
272
  if self.url.endswith(('.jpg', '.jpeg')):
225
273
  return 'image/jpeg'
@@ -241,7 +289,7 @@ class ImageUrl(FileUrl):
241
289
  return _image_format_lookup[self.media_type]
242
290
 
243
291
 
244
- @dataclass(repr=False)
292
+ @dataclass(init=False, repr=False)
245
293
  class DocumentUrl(FileUrl):
246
294
  """The URL of the document."""
247
295
 
@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
251
299
  kind: Literal['document-url'] = 'document-url'
252
300
  """Type identifier, this is available on all parts as a discriminator."""
253
301
 
254
- @property
255
- def media_type(self) -> str:
302
+ def __init__(
303
+ self,
304
+ url: str,
305
+ force_download: bool = False,
306
+ vendor_metadata: dict[str, Any] | None = None,
307
+ media_type: str | None = None,
308
+ kind: Literal['document-url'] = 'document-url',
309
+ ) -> None:
310
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
311
+ self.kind = kind
312
+
313
+ def _infer_media_type(self) -> str:
256
314
  """Return the media type of the document, based on the url."""
257
315
  type_, _ = guess_type(self.url)
258
316
  if type_ is None:
@@ -632,7 +690,7 @@ class ThinkingPart:
632
690
 
633
691
  def has_content(self) -> bool:
634
692
  """Return `True` if the thinking content is non-empty."""
635
- return bool(self.content) # pragma: no cover
693
+ return bool(self.content)
636
694
 
637
695
  __repr__ = _utils.dataclasses_no_defaults_repr
638
696
 
@@ -233,6 +233,15 @@ KnownModelName = TypeAliasType(
233
233
  'mistral:mistral-large-latest',
234
234
  'mistral:mistral-moderation-latest',
235
235
  'mistral:mistral-small-latest',
236
+ 'moonshotai:moonshot-v1-8k',
237
+ 'moonshotai:moonshot-v1-32k',
238
+ 'moonshotai:moonshot-v1-128k',
239
+ 'moonshotai:moonshot-v1-8k-vision-preview',
240
+ 'moonshotai:moonshot-v1-32k-vision-preview',
241
+ 'moonshotai:moonshot-v1-128k-vision-preview',
242
+ 'moonshotai:kimi-latest',
243
+ 'moonshotai:kimi-thinking-preview',
244
+ 'moonshotai:kimi-k2-0711-preview',
236
245
  'o1',
237
246
  'o1-2024-12-17',
238
247
  'o1-mini',
@@ -615,7 +624,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
615
624
  'deepseek',
616
625
  'azure',
617
626
  'openrouter',
627
+ 'vercel',
618
628
  'grok',
629
+ 'moonshotai',
619
630
  'fireworks',
620
631
  'together',
621
632
  'heroku',
@@ -758,7 +769,7 @@ async def download_item(
758
769
 
759
770
  data_type = media_type
760
771
  if type_format == 'extension':
761
- data_type = data_type.split('/')[1]
772
+ data_type = item.format
762
773
 
763
774
  data = response.content
764
775
  if data_format in ('base64', 'base64_uri'):
@@ -470,7 +470,7 @@ class AnthropicStreamedResponse(StreamedResponse):
470
470
  _response: AsyncIterable[BetaRawMessageStreamEvent]
471
471
  _timestamp: datetime
472
472
 
473
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
473
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
474
474
  current_block: BetaContentBlock | None = None
475
475
 
476
476
  async for event in self._response:
@@ -479,7 +479,11 @@ class AnthropicStreamedResponse(StreamedResponse):
479
479
  if isinstance(event, BetaRawContentBlockStartEvent):
480
480
  current_block = event.content_block
481
481
  if isinstance(current_block, BetaTextBlock) and current_block.text:
482
- yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
482
+ maybe_event = self._parts_manager.handle_text_delta(
483
+ vendor_part_id='content', content=current_block.text
484
+ )
485
+ if maybe_event is not None: # pragma: no branch
486
+ yield maybe_event
483
487
  elif isinstance(current_block, BetaThinkingBlock):
484
488
  yield self._parts_manager.handle_thinking_delta(
485
489
  vendor_part_id='thinking',
@@ -498,7 +502,11 @@ class AnthropicStreamedResponse(StreamedResponse):
498
502
 
499
503
  elif isinstance(event, BetaRawContentBlockDeltaEvent):
500
504
  if isinstance(event.delta, BetaTextDelta):
501
- yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
505
+ maybe_event = self._parts_manager.handle_text_delta(
506
+ vendor_part_id='content', content=event.delta.text
507
+ )
508
+ if maybe_event is not None: # pragma: no branch
509
+ yield maybe_event
502
510
  elif isinstance(event.delta, BetaThinkingDelta):
503
511
  yield self._parts_manager.handle_thinking_delta(
504
512
  vendor_part_id='thinking', content=event.delta.thinking
@@ -572,7 +572,7 @@ class BedrockStreamedResponse(StreamedResponse):
572
572
  _event_stream: EventStream[ConverseStreamOutputTypeDef]
573
573
  _timestamp: datetime = field(default_factory=_utils.now_utc)
574
574
 
575
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
575
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
576
576
  """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
577
577
 
578
578
  This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -618,7 +618,9 @@ class BedrockStreamedResponse(StreamedResponse):
618
618
  UserWarning,
619
619
  )
620
620
  if 'text' in delta:
621
- yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
621
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
622
+ if maybe_event is not None:
623
+ yield maybe_event
622
624
  if 'toolUse' in delta:
623
625
  tool_use = delta['toolUse']
624
626
  maybe_event = self._parts_manager.handle_tool_call_delta(
@@ -38,15 +38,15 @@ try:
38
38
  AssistantChatMessageV2,
39
39
  AsyncClientV2,
40
40
  ChatMessageV2,
41
- ChatResponse,
42
41
  SystemChatMessageV2,
43
- TextAssistantMessageContentItem,
42
+ TextAssistantMessageV2ContentItem,
44
43
  ToolCallV2,
45
44
  ToolCallV2Function,
46
45
  ToolChatMessageV2,
47
46
  ToolV2,
48
47
  ToolV2Function,
49
48
  UserChatMessageV2,
49
+ V2ChatResponse,
50
50
  )
51
51
  from cohere.core.api_error import ApiError
52
52
  from cohere.v2.client import OMIT
@@ -164,7 +164,7 @@ class CohereModel(Model):
164
164
  messages: list[ModelMessage],
165
165
  model_settings: CohereModelSettings,
166
166
  model_request_parameters: ModelRequestParameters,
167
- ) -> ChatResponse:
167
+ ) -> V2ChatResponse:
168
168
  tools = self._get_tools(model_request_parameters)
169
169
  cohere_messages = self._map_messages(messages)
170
170
  try:
@@ -185,7 +185,7 @@ class CohereModel(Model):
185
185
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
186
186
  raise # pragma: no cover
187
187
 
188
- def _process_response(self, response: ChatResponse) -> ModelResponse:
188
+ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
189
189
  """Process a non-streamed response, and prepare a message to return."""
190
190
  parts: list[ModelResponsePart] = []
191
191
  if response.message.content is not None and len(response.message.content) > 0:
@@ -227,7 +227,7 @@ class CohereModel(Model):
227
227
  assert_never(item)
228
228
  message_param = AssistantChatMessageV2(role='assistant')
229
229
  if texts:
230
- message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
230
+ message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
231
231
  if tool_calls:
232
232
  message_param.tool_calls = tool_calls
233
233
  cohere_messages.append(message_param)
@@ -294,7 +294,7 @@ class CohereModel(Model):
294
294
  assert_never(part)
295
295
 
296
296
 
297
- def _map_usage(response: ChatResponse) -> usage.Usage:
297
+ def _map_usage(response: V2ChatResponse) -> usage.Usage:
298
298
  u = response.usage
299
299
  if u is None:
300
300
  return usage.Usage()
@@ -16,9 +16,7 @@ from pydantic_ai.profiles import ModelProfileSpec
16
16
  from .. import _utils, usage
17
17
  from .._utils import PeekableAsyncStream
18
18
  from ..messages import (
19
- AudioUrl,
20
19
  BinaryContent,
21
- ImageUrl,
22
20
  ModelMessage,
23
21
  ModelRequest,
24
22
  ModelResponse,
@@ -266,7 +264,9 @@ class FunctionStreamedResponse(StreamedResponse):
266
264
  if isinstance(item, str):
267
265
  response_tokens = _estimate_string_tokens(item)
268
266
  self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
269
- yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
267
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
268
+ if maybe_event is not None: # pragma: no branch
269
+ yield maybe_event
270
270
  elif isinstance(item, dict) and item:
271
271
  for dtc_index, delta in item.items():
272
272
  if isinstance(delta, DeltaThinkingPart):
@@ -288,7 +288,7 @@ class FunctionStreamedResponse(StreamedResponse):
288
288
  args=delta.json_args,
289
289
  tool_call_id=delta.tool_call_id,
290
290
  )
291
- if maybe_event is not None:
291
+ if maybe_event is not None: # pragma: no branch
292
292
  yield maybe_event
293
293
  else:
294
294
  assert_never(delta)
@@ -345,18 +345,19 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
345
345
  def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
346
346
  if not content:
347
347
  return 0
348
+
348
349
  if isinstance(content, str):
349
- return len(re.split(r'[\s",.:]+', content.strip()))
350
- else:
351
- tokens = 0
352
- for part in content:
353
- if isinstance(part, str):
354
- tokens += len(re.split(r'[\s",.:]+', part.strip()))
355
- # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
356
- if isinstance(part, (AudioUrl, ImageUrl)):
357
- tokens += 0
358
- elif isinstance(part, BinaryContent):
359
- tokens += len(part.data)
360
- else:
361
- tokens += 0
362
- return tokens
350
+ return len(_TOKEN_SPLIT_RE.split(content.strip()))
351
+
352
+ tokens = 0
353
+ for part in content:
354
+ if isinstance(part, str):
355
+ tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
356
+ elif isinstance(part, BinaryContent):
357
+ tokens += len(part.data)
358
+ # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
359
+
360
+ return tokens
361
+
362
+
363
+ _TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
@@ -438,7 +438,11 @@ class GeminiStreamedResponse(StreamedResponse):
438
438
  if 'text' in gemini_part:
439
439
  # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
440
440
  # amongst the tool call deltas
441
- yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
441
+ maybe_event = self._parts_manager.handle_text_delta(
442
+ vendor_part_id=None, content=gemini_part['text']
443
+ )
444
+ if maybe_event is not None: # pragma: no branch
445
+ yield maybe_event
442
446
 
443
447
  elif 'function_call' in gemini_part:
444
448
  # Here, we assume all function_call parts are complete and don't have deltas.
@@ -411,7 +411,12 @@ class GoogleModel(Model):
411
411
  file_data_dict['video_metadata'] = item.vendor_metadata
412
412
  content.append(file_data_dict) # type: ignore
413
413
  elif isinstance(item, FileUrl):
414
- if self.system == 'google-gla' or item.force_download:
414
+ if item.force_download or (
415
+ # google-gla does not support passing file urls directly, except for youtube videos
416
+ # (see above) and files uploaded to the file API (which cannot be downloaded anyway)
417
+ self.system == 'google-gla'
418
+ and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
419
+ ):
415
420
  downloaded_item = await download_item(item, data_format='base64')
416
421
  inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
417
422
  content.append({'inline_data': inline_data}) # type: ignore
@@ -453,7 +458,9 @@ class GeminiStreamedResponse(StreamedResponse):
453
458
  if part.thought:
454
459
  yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
455
460
  else:
456
- yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
461
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
462
+ if maybe_event is not None: # pragma: no branch
463
+ yield maybe_event
457
464
  elif part.function_call:
458
465
  maybe_event = self._parts_manager.handle_tool_call_delta(
459
466
  vendor_part_id=uuid4(),
@@ -415,7 +415,11 @@ class GroqStreamedResponse(StreamedResponse):
415
415
  # Handle the text part of the response
416
416
  content = choice.delta.content
417
417
  if content is not None:
418
- yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
418
+ maybe_event = self._parts_manager.handle_text_delta(
419
+ vendor_part_id='content', content=content, extract_think_tags=True
420
+ )
421
+ if maybe_event is not None: # pragma: no branch
422
+ yield maybe_event
419
423
 
420
424
  # Handle the tool calls
421
425
  for dtc in choice.delta.tool_calls or []:
@@ -444,7 +448,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
444
448
  if isinstance(completion, chat.ChatCompletion):
445
449
  response_usage = completion.usage
446
450
  elif completion.x_groq is not None:
447
- response_usage = completion.x_groq.usage # pragma: no cover
451
+ response_usage = completion.x_groq.usage
448
452
 
449
453
  if response_usage is None:
450
454
  return usage.Usage()
@@ -426,8 +426,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
426
426
 
427
427
  # Handle the text part of the response
428
428
  content = choice.delta.content
429
- if content is not None:
430
- yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
429
+ if content:
430
+ maybe_event = self._parts_manager.handle_text_delta(
431
+ vendor_part_id='content', content=content, extract_think_tags=True
432
+ )
433
+ if maybe_event is not None: # pragma: no branch
434
+ yield maybe_event
431
435
 
432
436
  for dtc in choice.delta.tool_calls or []:
433
437
  maybe_event = self._parts_manager.handle_tool_call_delta(