pydantic-ai-slim 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_ai/messages.py CHANGED
@@ -85,7 +85,7 @@ class SystemPromptPart:
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
-@dataclass(repr=False)
+@dataclass(init=False, repr=False)
 class FileUrl(ABC):
     """Abstract base class for any URL-based file."""
 
@@ -106,11 +106,29 @@ class FileUrl(ABC):
     - `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
     """
 
-    @property
+    _media_type: str | None = field(init=False, repr=False)
+
+    def __init__(
+        self,
+        url: str,
+        force_download: bool = False,
+        vendor_metadata: dict[str, Any] | None = None,
+        media_type: str | None = None,
+    ) -> None:
+        self.url = url
+        self.vendor_metadata = vendor_metadata
+        self.force_download = force_download
+        self._media_type = media_type
+
     @abstractmethod
-    def media_type(self) -> str:
+    def _infer_media_type(self) -> str:
         """Return the media type of the file, based on the url."""
 
+    @property
+    def media_type(self) -> str:
+        """Return the media type of the file, based on the url or the provided `_media_type`."""
+        return self._media_type or self._infer_media_type()
+
     @property
     @abstractmethod
     def format(self) -> str:
@@ -119,7 +137,7 @@ class FileUrl(ABC):
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
-@dataclass(repr=False)
+@dataclass(init=False, repr=False)
 class VideoUrl(FileUrl):
     """A URL to a video."""
 
@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
     kind: Literal['video-url'] = 'video-url'
     """Type identifier, this is available on all parts as a discriminator."""
 
-    @property
-    def media_type(self) -> VideoMediaType:
+    def __init__(
+        self,
+        url: str,
+        force_download: bool = False,
+        vendor_metadata: dict[str, Any] | None = None,
+        media_type: str | None = None,
+        kind: Literal['video-url'] = 'video-url',
+    ) -> None:
+        super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
+        self.kind = kind
+
+    def _infer_media_type(self) -> VideoMediaType:
         """Return the media type of the video, based on the url."""
         if self.url.endswith('.mkv'):
             return 'video/x-matroska'
@@ -170,7 +198,7 @@ class VideoUrl(FileUrl):
         return _video_format_lookup[self.media_type]
 
 
-@dataclass(repr=False)
+@dataclass(init=False, repr=False)
 class AudioUrl(FileUrl):
     """A URL to an audio file."""
 
@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
     kind: Literal['audio-url'] = 'audio-url'
     """Type identifier, this is available on all parts as a discriminator."""
 
-    @property
-    def media_type(self) -> AudioMediaType:
+    def __init__(
+        self,
+        url: str,
+        force_download: bool = False,
+        vendor_metadata: dict[str, Any] | None = None,
+        media_type: str | None = None,
+        kind: Literal['audio-url'] = 'audio-url',
+    ) -> None:
+        super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
+        self.kind = kind
+
+    def _infer_media_type(self) -> AudioMediaType:
         """Return the media type of the audio file, based on the url.
 
         References:
@@ -208,7 +246,7 @@ class AudioUrl(FileUrl):
         return _audio_format_lookup[self.media_type]
 
 
-@dataclass(repr=False)
+@dataclass(init=False, repr=False)
 class ImageUrl(FileUrl):
     """A URL to an image."""
 
@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
     kind: Literal['image-url'] = 'image-url'
     """Type identifier, this is available on all parts as a discriminator."""
 
-    @property
-    def media_type(self) -> ImageMediaType:
+    def __init__(
+        self,
+        url: str,
+        force_download: bool = False,
+        vendor_metadata: dict[str, Any] | None = None,
+        media_type: str | None = None,
+        kind: Literal['image-url'] = 'image-url',
+    ) -> None:
+        super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
+        self.kind = kind
+
+    def _infer_media_type(self) -> ImageMediaType:
         """Return the media type of the image, based on the url."""
         if self.url.endswith(('.jpg', '.jpeg')):
             return 'image/jpeg'
@@ -241,7 +289,7 @@ class ImageUrl(FileUrl):
         return _image_format_lookup[self.media_type]
 
 
-@dataclass(repr=False)
+@dataclass(init=False, repr=False)
 class DocumentUrl(FileUrl):
     """The URL of the document."""
 
@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
     kind: Literal['document-url'] = 'document-url'
     """Type identifier, this is available on all parts as a discriminator."""
 
-    @property
-    def media_type(self) -> str:
+    def __init__(
+        self,
+        url: str,
+        force_download: bool = False,
+        vendor_metadata: dict[str, Any] | None = None,
+        media_type: str | None = None,
+        kind: Literal['document-url'] = 'document-url',
+    ) -> None:
+        super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
+        self.kind = kind
+
+    def _infer_media_type(self) -> str:
         """Return the media type of the document, based on the url."""
         type_, _ = guess_type(self.url)
         if type_ is None:
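
Note: the net effect of the `@dataclass(init=False)` hunks above is a new optional `media_type` argument on every URL part, which bypasses extension-based inference. A minimal sketch of the resulting behaviour (the example URLs are illustrative):

```python
from pydantic_ai.messages import DocumentUrl, ImageUrl

# Extension-based inference works as before:
image = ImageUrl(url='https://example.com/photo.jpg')
assert image.media_type == 'image/jpeg'

# A URL without a usable extension can now carry an explicit media type,
# so `_infer_media_type()` is never consulted:
doc = DocumentUrl(url='https://example.com/download?id=123', media_type='application/pdf')
assert doc.media_type == 'application/pdf'
```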
@@ -632,7 +690,7 @@ class ThinkingPart:
 
     def has_content(self) -> bool:
         """Return `True` if the thinking content is non-empty."""
-        return bool(self.content)  # pragma: no cover
+        return bool(self.content)
 
     __repr__ = _utils.dataclasses_no_defaults_repr
 
@@ -757,11 +815,16 @@ class ModelResponse:
                         },
                     }
                 )
-            elif isinstance(part, TextPart):
-                if body.get('content'):
-                    body = new_event_body()
-                if settings.include_content:
-                    body['content'] = part.content
+            elif isinstance(part, (TextPart, ThinkingPart)):
+                kind = part.part_kind
+                body.setdefault('content', []).append(
+                    {'kind': kind, **({'text': part.content} if settings.include_content else {})}
+                )
+
+        if content := body.get('content'):
+            text_content = content[0].get('text')
+            if content == [{'kind': 'text', 'text': text_content}]:
+                body['content'] = text_content
 
         return result
 

pydantic_ai/models/__init__.py CHANGED
@@ -233,6 +233,15 @@ KnownModelName = TypeAliasType(
         'mistral:mistral-large-latest',
         'mistral:mistral-moderation-latest',
         'mistral:mistral-small-latest',
+        'moonshotai:moonshot-v1-8k',
+        'moonshotai:moonshot-v1-32k',
+        'moonshotai:moonshot-v1-128k',
+        'moonshotai:moonshot-v1-8k-vision-preview',
+        'moonshotai:moonshot-v1-32k-vision-preview',
+        'moonshotai:moonshot-v1-128k-vision-preview',
+        'moonshotai:kimi-latest',
+        'moonshotai:kimi-thinking-preview',
+        'moonshotai:kimi-k2-0711-preview',
         'o1',
         'o1-2024-12-17',
         'o1-mini',
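
Note: these strings extend `KnownModelName`, so the new MoonshotAI models can be referenced directly as model strings. A sketch, assuming the provider reads its credentials from a `MOONSHOTAI_API_KEY` environment variable:

```python
from pydantic_ai import Agent

# 'moonshotai:<model>' now type-checks as a KnownModelName
agent = Agent('moonshotai:kimi-k2-0711-preview')
```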
@@ -615,7 +624,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'deepseek',
         'azure',
         'openrouter',
+        'vercel',
         'grok',
+        'moonshotai',
         'fireworks',
         'together',
         'heroku',

pydantic_ai/models/anthropic.py CHANGED
@@ -470,7 +470,7 @@ class AnthropicStreamedResponse(StreamedResponse):
     _response: AsyncIterable[BetaRawMessageStreamEvent]
     _timestamp: datetime
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         current_block: BetaContentBlock | None = None
 
         async for event in self._response:
@@ -479,7 +479,11 @@ class AnthropicStreamedResponse(StreamedResponse):
             if isinstance(event, BetaRawContentBlockStartEvent):
                 current_block = event.content_block
                 if isinstance(current_block, BetaTextBlock) and current_block.text:
-                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content', content=current_block.text
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
                 elif isinstance(current_block, BetaThinkingBlock):
                     yield self._parts_manager.handle_thinking_delta(
                         vendor_part_id='thinking',
@@ -498,7 +502,11 @@ class AnthropicStreamedResponse(StreamedResponse):
 
             elif isinstance(event, BetaRawContentBlockDeltaEvent):
                 if isinstance(event.delta, BetaTextDelta):
-                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content', content=event.delta.text
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
                 elif isinstance(event.delta, BetaThinkingDelta):
                     yield self._parts_manager.handle_thinking_delta(
                         vendor_part_id='thinking', content=event.delta.thinking
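
Note: this guard recurs in nearly every streamed-response hunk in this release. `handle_text_delta` can now return `None` instead of always producing an event (for example while buffering text when think-tag extraction is enabled), so call sites must check before yielding. A sketch of the consumer contract, with `parts_manager` passed in for self-containment:

```python
from typing import Any, AsyncIterable, AsyncIterator


async def text_events(parts_manager: Any, chunks: AsyncIterable[str]) -> AsyncIterator[Any]:
    # `handle_text_delta` returns an event or None; None means the delta was
    # consumed (e.g. buffered for think-tag extraction) without emitting.
    async for chunk in chunks:
        maybe_event = parts_manager.handle_text_delta(vendor_part_id='content', content=chunk)
        if maybe_event is not None:
            yield maybe_event
```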

pydantic_ai/models/bedrock.py CHANGED
@@ -572,7 +572,7 @@ class BedrockStreamedResponse(StreamedResponse):
     _event_stream: EventStream[ConverseStreamOutputTypeDef]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -618,7 +618,9 @@ class BedrockStreamedResponse(StreamedResponse):
                             UserWarning,
                         )
                 if 'text' in delta:
-                    yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
+                    maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
+                    if maybe_event is not None:
+                        yield maybe_event
                 if 'toolUse' in delta:
                     tool_use = delta['toolUse']
                     maybe_event = self._parts_manager.handle_tool_call_delta(

pydantic_ai/models/cohere.py CHANGED
@@ -38,15 +38,15 @@ try:
         AssistantChatMessageV2,
         AsyncClientV2,
         ChatMessageV2,
-        ChatResponse,
         SystemChatMessageV2,
-        TextAssistantMessageContentItem,
+        TextAssistantMessageV2ContentItem,
         ToolCallV2,
         ToolCallV2Function,
         ToolChatMessageV2,
         ToolV2,
         ToolV2Function,
         UserChatMessageV2,
+        V2ChatResponse,
     )
     from cohere.core.api_error import ApiError
     from cohere.v2.client import OMIT
@@ -164,7 +164,7 @@ class CohereModel(Model):
         messages: list[ModelMessage],
         model_settings: CohereModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> ChatResponse:
+    ) -> V2ChatResponse:
         tools = self._get_tools(model_request_parameters)
         cohere_messages = self._map_messages(messages)
         try:
@@ -185,7 +185,7 @@ class CohereModel(Model):
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: no cover
 
-    def _process_response(self, response: ChatResponse) -> ModelResponse:
+    def _process_response(self, response: V2ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         parts: list[ModelResponsePart] = []
         if response.message.content is not None and len(response.message.content) > 0:
@@ -227,7 +227,7 @@ class CohereModel(Model):
                     assert_never(item)
             message_param = AssistantChatMessageV2(role='assistant')
             if texts:
-                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+                message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
             if tool_calls:
                 message_param.tool_calls = tool_calls
             cohere_messages.append(message_param)
@@ -294,7 +294,7 @@ class CohereModel(Model):
                 assert_never(part)
 
 
-def _map_usage(response: ChatResponse) -> usage.Usage:
+def _map_usage(response: V2ChatResponse) -> usage.Usage:
     u = response.usage
     if u is None:
         return usage.Usage()

pydantic_ai/models/function.py CHANGED
@@ -264,7 +264,9 @@ class FunctionStreamedResponse(StreamedResponse):
             if isinstance(item, str):
                 response_tokens = _estimate_string_tokens(item)
                 self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
-                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
             elif isinstance(item, dict) and item:
                 for dtc_index, delta in item.items():
                     if isinstance(delta, DeltaThinkingPart):
@@ -286,7 +288,7 @@ class FunctionStreamedResponse(StreamedResponse):
                             args=delta.json_args,
                             tool_call_id=delta.tool_call_id,
                         )
-                        if maybe_event is not None:
+                        if maybe_event is not None:  # pragma: no branch
                             yield maybe_event
                     else:
                         assert_never(delta)

pydantic_ai/models/gemini.py CHANGED
@@ -438,7 +438,11 @@ class GeminiStreamedResponse(StreamedResponse):
             if 'text' in gemini_part:
                 # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
                 # amongst the tool call deltas
-                yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id=None, content=gemini_part['text']
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
 
             elif 'function_call' in gemini_part:
                 # Here, we assume all function_call parts are complete and don't have deltas.

pydantic_ai/models/google.py CHANGED
@@ -411,7 +411,12 @@ class GoogleModel(Model):
                     file_data_dict['video_metadata'] = item.vendor_metadata
                 content.append(file_data_dict)  # type: ignore
             elif isinstance(item, FileUrl):
-                if self.system == 'google-gla' or item.force_download:
+                if item.force_download or (
+                    # google-gla does not support passing file urls directly, except for youtube videos
+                    # (see above) and files uploaded to the file API (which cannot be downloaded anyway)
+                    self.system == 'google-gla'
+                    and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
+                ):
                     downloaded_item = await download_item(item, data_format='base64')
                     inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
                     content.append({'inline_data': inline_data})  # type: ignore
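
Note: one consequence of the widened condition is that a file already uploaded to the Gemini Files API is passed through as a URL rather than pointlessly (and unsuccessfully) downloaded. A hedged sketch — the file ID is made up, and the explicit `media_type` uses the new keyword from `messages.py`:

```python
from pydantic_ai.messages import DocumentUrl

# URLs under this prefix now skip the download branch on google-gla:
uploaded = DocumentUrl(
    url='https://generativelanguage.googleapis.com/v1beta/files/abc123',
    media_type='application/pdf',
)
```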
@@ -453,7 +458,9 @@ class GeminiStreamedResponse(StreamedResponse):
                 if part.thought:
                     yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
                 else:
-                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
+                    maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
             elif part.function_call:
                 maybe_event = self._parts_manager.handle_tool_call_delta(
                     vendor_part_id=uuid4(),

pydantic_ai/models/groq.py CHANGED
@@ -415,7 +415,11 @@ class GroqStreamedResponse(StreamedResponse):
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
-                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content', content=content, extract_think_tags=True
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
 
             # Handle the tool calls
             for dtc in choice.delta.tool_calls or []:
@@ -444,7 +448,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
     elif completion.x_groq is not None:
-        response_usage = completion.x_groq.usage  # pragma: no cover
+        response_usage = completion.x_groq.usage
 
     if response_usage is None:
         return usage.Usage()

pydantic_ai/models/huggingface.py CHANGED
@@ -426,8 +426,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
            content = choice.delta.content
-            if content is not None:
-                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+            if content:
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content', content=content, extract_think_tags=True
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
 
             for dtc in choice.delta.tool_calls or []:
                 maybe_event = self._parts_manager.handle_tool_call_delta(

pydantic_ai/models/mistral.py CHANGED
@@ -601,7 +601,9 @@ class MistralStreamedResponse(StreamedResponse):
                         tool_call_id=maybe_tool_call_part.tool_call_id,
                     )
                 else:
-                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
+                    maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
 
             # Handle the explicit tool calls
             for index, dtc in enumerate(choice.delta.tool_calls or []):

pydantic_ai/models/openai.py CHANGED
@@ -17,7 +17,7 @@ from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
-from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
+from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -191,7 +191,17 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'moonshotai',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
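
Note: the widened `Literal` means the new provider names are accepted (and type-checked) when constructing the model directly; `provider='vercel'` works the same way. A minimal sketch, assuming provider credentials come from the environment:

```python
from pydantic_ai.models.openai import OpenAIModel

moonshot_model = OpenAIModel('kimi-k2-0711-preview', provider='moonshotai')
```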
@@ -290,7 +300,10 @@ class OpenAIModel(Model):
         tools = self._get_tools(model_request_parameters)
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_output:
+        elif (
+            not model_request_parameters.allow_text_output
+            and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
+        ):
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
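
Note: restated as a standalone function (the name and parameters are illustrative, not part of the API), the tool-choice decision is now:

```python
from typing import Literal, Optional


def choose_tool_choice(
    has_tools: bool, allow_text_output: bool, supports_required: bool
) -> Optional[Literal['none', 'required', 'auto']]:
    """Illustrative restatement of the branch above."""
    if not has_tools:
        return None
    elif not allow_text_output and supports_required:
        return 'required'
    else:
        # Providers that reject `tool_choice='required'` (e.g. MoonshotAI)
        # now fall back to 'auto' instead of failing the request.
        return 'auto'


assert choose_tool_choice(True, False, False) == 'auto'
assert choose_tool_choice(True, False, True) == 'required'
```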
@@ -357,11 +370,17 @@ class OpenAIModel(Model):
         if not isinstance(response, chat.ChatCompletion):
             raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
 
+        if response.created:
+            timestamp = number_to_datetime(response.created)
+        else:
+            timestamp = _now_utc()
+            response.created = int(timestamp.timestamp())
+
         try:
             response = chat.ChatCompletion.model_validate(response.model_dump())
         except ValidationError as e:
             raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e
-        timestamp = number_to_datetime(response.created)
+
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         # The `reasoning_content` is only present in DeepSeek models.
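
Note: the new block handles OpenAI-compatible servers that return a missing or zero `created` field: instead of a 1970 timestamp, the current time is used and written back so the re-validated response stays consistent. The equivalent logic in isolation, using `datetime` directly in place of the internal helpers:

```python
from datetime import datetime, timezone

created = 0  # falsy value as returned by some OpenAI-compatible servers
if created:
    timestamp = datetime.fromtimestamp(created, tz=timezone.utc)
else:
    timestamp = datetime.now(tz=timezone.utc)
    created = int(timestamp.timestamp())  # written back onto the response
```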
@@ -1003,8 +1022,12 @@ class OpenAIStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
             content = choice.delta.content
-            if content is not None:
-                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+            if content:
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content', content=content, extract_think_tags=True
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
 
             # Handle reasoning part of the response, present in DeepSeek models
             if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1121,7 +1144,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 )
 
             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta)
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id=chunk.content_index, content=chunk.delta
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
 
             elif isinstance(chunk, responses.ResponseTextDoneEvent):
                 pass  # there's nothing we need to do here

pydantic_ai/models/test.py CHANGED
@@ -269,10 +269,14 @@ class TestStreamedResponse(StreamedResponse):
                     mid = len(text) // 2
                     words = [text[:mid], text[mid:]]
                 self._usage += _get_string_usage('')
-                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
                 for word in words:
                     self._usage += _get_string_usage(word)
-                    yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+                    maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
             elif isinstance(part, ToolCallPart):
                 yield self._parts_manager.handle_tool_call_part(
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id

pydantic_ai/profiles/openai.py CHANGED
@@ -21,6 +21,14 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_sampling_settings: bool = True
     """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
 
+    # Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
+    # `tool_choice="required"`. This flag lets the calling model know whether it's
+    # safe to pass that value along. Default is `True` to preserve existing
+    # behaviour for OpenAI itself and most providers.
+    openai_supports_tool_choice_required: bool = True
+    """Whether the provider accepts the value ``tool_choice='required'`` in the
+    request payload."""
+
 
 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""

pydantic_ai/providers/__init__.py CHANGED
@@ -62,6 +62,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .openrouter import OpenRouterProvider
 
         return OpenRouterProvider
+    elif provider == 'vercel':
+        from .vercel import VercelProvider
+
+        return VercelProvider
     elif provider == 'azure':
         from .azure import AzureProvider
 
@@ -99,6 +103,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .grok import GrokProvider
 
         return GrokProvider
+    elif provider == 'moonshotai':
+        from .moonshotai import MoonshotAIProvider
+
+        return MoonshotAIProvider
     elif provider == 'fireworks':
         from .fireworks import FireworksProvider
 
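
Note: with these branches in place, the `vercel` and `moonshotai` prefixes resolve through provider inference like the existing ones. A sketch, assuming the respective API-key environment variables are set so the providers can be instantiated:

```python
from pydantic_ai.providers import infer_provider

vercel_provider = infer_provider('vercel')        # VercelProvider instance
moonshot_provider = infer_provider('moonshotai')  # MoonshotAIProvider instance
```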