pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.


@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import io
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
@@ -13,7 +13,7 @@ from typing_extensions import assert_never
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._utils import guard_tool_call_id as _guard_tool_call_id
-from ..builtin_tools import CodeExecutionTool, MemoryTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -68,6 +68,9 @@ try:
         BetaContentBlockParam,
         BetaImageBlockParam,
         BetaInputJSONDelta,
+        BetaMCPToolResultBlock,
+        BetaMCPToolUseBlock,
+        BetaMCPToolUseBlockParam,
         BetaMemoryTool20250818Param,
         BetaMessage,
         BetaMessageParam,
@@ -82,6 +85,8 @@ try:
         BetaRawMessageStreamEvent,
         BetaRedactedThinkingBlock,
         BetaRedactedThinkingBlockParam,
+        BetaRequestMCPServerToolConfigurationParam,
+        BetaRequestMCPServerURLDefinitionParam,
         BetaServerToolUseBlock,
         BetaServerToolUseBlockParam,
         BetaSignatureDelta,
@@ -162,7 +167,7 @@ class AnthropicModel(Model):
         self,
         model_name: AnthropicModelName,
         *,
-        provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
+        provider: Literal['anthropic', 'gateway'] | Provider[AsyncAnthropicClient] = 'anthropic',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -179,7 +184,7 @@ class AnthropicModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/anthropic' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
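The new 'gateway' literal is rewritten to 'gateway/anthropic' before the call to infer_provider, routing requests through the Pydantic AI Gateway; the Bedrock, Google, and Groq models further down gain the same option. A minimal sketch of the new parameter, assuming gateway credentials are already configured in the environment (the model name is illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel

# provider='gateway' resolves to infer_provider('gateway/anthropic') internally
model = AnthropicModel('claude-sonnet-4-0', provider='gateway')
agent = Agent(model)
print(agent.run_sync('Hello!').output)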
@@ -264,7 +269,7 @@ class AnthropicModel(Model):
     ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-        tools, beta_features = self._add_builtin_tools(tools, model_request_parameters)
+        tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)
 
         tool_choice: BetaToolChoiceParam | None
 
@@ -300,6 +305,7 @@ class AnthropicModel(Model):
                 model=self._model_name,
                 tools=tools or OMIT,
                 tool_choice=tool_choice or OMIT,
+                mcp_servers=mcp_servers or OMIT,
                 stream=stream,
                 thinking=model_settings.get('anthropic_thinking', OMIT),
                 stop_sequences=model_settings.get('stop_sequences', OMIT),
@@ -318,11 +324,14 @@ class AnthropicModel(Model):
     def _process_response(self, response: BetaMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
+        builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
         for item in response.content:
             if isinstance(item, BetaTextBlock):
                 items.append(TextPart(content=item.text))
             elif isinstance(item, BetaServerToolUseBlock):
-                items.append(_map_server_tool_use_block(item, self.system))
+                call_part = _map_server_tool_use_block(item, self.system)
+                builtin_tool_calls[call_part.tool_call_id] = call_part
+                items.append(call_part)
             elif isinstance(item, BetaWebSearchToolResultBlock):
                 items.append(_map_web_search_tool_result_block(item, self.system))
             elif isinstance(item, BetaCodeExecutionToolResultBlock):
@@ -333,6 +342,13 @@ class AnthropicModel(Model):
                 )
             elif isinstance(item, BetaThinkingBlock):
                 items.append(ThinkingPart(content=item.thinking, signature=item.signature, provider_name=self.system))
+            elif isinstance(item, BetaMCPToolUseBlock):
+                call_part = _map_mcp_server_use_block(item, self.system)
+                builtin_tool_calls[call_part.tool_call_id] = call_part
+                items.append(call_part)
+            elif isinstance(item, BetaMCPToolResultBlock):
+                call_part = builtin_tool_calls.get(item.tool_use_id)
+                items.append(_map_mcp_server_result_block(item, call_part, self.system))
             else:
                 assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}'
                 items.append(
@@ -383,8 +399,9 @@ class AnthropicModel(Model):
 
     def _add_builtin_tools(
         self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters
-    ) -> tuple[list[BetaToolUnionParam], list[str]]:
+    ) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], list[str]]:
         beta_features: list[str] = []
+        mcp_servers: list[BetaRequestMCPServerURLDefinitionParam] = []
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
@@ -408,11 +425,26 @@ class AnthropicModel(Model):
                 tools = [tool for tool in tools if tool['name'] != 'memory']
                 tools.append(BetaMemoryTool20250818Param(name='memory', type='memory_20250818'))
                 beta_features.append('context-management-2025-06-27')
+            elif isinstance(tool, MCPServerTool) and tool.url:
+                mcp_server_url_definition_param = BetaRequestMCPServerURLDefinitionParam(
+                    type='url',
+                    name=tool.id,
+                    url=tool.url,
+                )
+                if tool.allowed_tools is not None:  # pragma: no branch
+                    mcp_server_url_definition_param['tool_configuration'] = BetaRequestMCPServerToolConfigurationParam(
+                        enabled=bool(tool.allowed_tools),
+                        allowed_tools=tool.allowed_tools,
+                    )
+                if tool.authorization_token:  # pragma: no cover
+                    mcp_server_url_definition_param['authorization_token'] = tool.authorization_token
+                mcp_servers.append(mcp_server_url_definition_param)
+                beta_features.append('mcp-client-2025-04-04')
             else:  # pragma: no cover
                 raise UserError(
                     f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
                 )
-        return tools, beta_features
+        return tools, mcp_servers, beta_features
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
@@ -458,6 +490,8 @@ class AnthropicModel(Model):
                 | BetaCodeExecutionToolResultBlockParam
                 | BetaThinkingBlockParam
                 | BetaRedactedThinkingBlockParam
+                | BetaMCPToolUseBlockParam
+                | BetaMCPToolResultBlock
             ] = []
             for response_part in m.parts:
                 if isinstance(response_part, TextPart):
@@ -508,7 +542,7 @@ class AnthropicModel(Model):
                             input=response_part.args_as_dict(),
                         )
                         assistant_content_params.append(server_tool_use_block_param)
-                    elif response_part.tool_name == CodeExecutionTool.kind:  # pragma: no branch
+                    elif response_part.tool_name == CodeExecutionTool.kind:
                         server_tool_use_block_param = BetaServerToolUseBlockParam(
                             id=tool_use_id,
                             type='server_tool_use',
@@ -516,6 +550,21 @@ class AnthropicModel(Model):
                             input=response_part.args_as_dict(),
                         )
                         assistant_content_params.append(server_tool_use_block_param)
+                    elif (
+                        response_part.tool_name.startswith(MCPServerTool.kind)
+                        and (server_id := response_part.tool_name.split(':', 1)[1])
+                        and (args := response_part.args_as_dict())
+                        and (tool_name := args.get('tool_name'))
+                        and (tool_args := args.get('tool_args'))
+                    ):  # pragma: no branch
+                        mcp_tool_use_block_param = BetaMCPToolUseBlockParam(
+                            id=tool_use_id,
+                            type='mcp_tool_use',
+                            server_name=server_id,
+                            name=tool_name,
+                            input=tool_args,
+                        )
+                        assistant_content_params.append(mcp_tool_use_block_param)
                 elif isinstance(response_part, BuiltinToolReturnPart):
                     if response_part.provider_name == self.system:
                         tool_use_id = _guard_tool_call_id(t=response_part)
@@ -547,6 +596,16 @@ class AnthropicModel(Model):
                                 ),
                             )
                         )
+                    elif response_part.tool_name.startswith(MCPServerTool.kind) and isinstance(
+                        response_part.content, dict
+                    ):  # pragma: no branch
+                        assistant_content_params.append(
+                            BetaMCPToolResultBlock(
+                                tool_use_id=tool_use_id,
+                                type='mcp_tool_result',
+                                **cast(dict[str, Any], response_part.content),  # pyright: ignore[reportUnknownMemberType]
+                            )
+                        )
                 elif isinstance(response_part, FilePart):  # pragma: no cover
                     # Files generated by models are not sent back to models that don't themselves generate files.
                     pass
@@ -661,6 +720,7 @@ class AnthropicStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         current_block: BetaContentBlock | None = None
 
+        builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
         async for event in self._response:
             if isinstance(event, BetaRawMessageStartEvent):
                 self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
@@ -698,9 +758,11 @@ class AnthropicStreamedResponse(StreamedResponse):
                     if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
                 elif isinstance(current_block, BetaServerToolUseBlock):
+                    call_part = _map_server_tool_use_block(current_block, self.provider_name)
+                    builtin_tool_calls[call_part.tool_call_id] = call_part
                     yield self._parts_manager.handle_part(
                         vendor_part_id=event.index,
-                        part=_map_server_tool_use_block(current_block, self.provider_name),
+                        part=call_part,
                     )
                 elif isinstance(current_block, BetaWebSearchToolResultBlock):
                     yield self._parts_manager.handle_part(
@@ -712,6 +774,32 @@ class AnthropicStreamedResponse(StreamedResponse):
                         vendor_part_id=event.index,
                         part=_map_code_execution_tool_result_block(current_block, self.provider_name),
                     )
+                elif isinstance(current_block, BetaMCPToolUseBlock):
+                    call_part = _map_mcp_server_use_block(current_block, self.provider_name)
+                    builtin_tool_calls[call_part.tool_call_id] = call_part
+
+                    args_json = call_part.args_as_json_str()
+                    # Drop the final `{}}` so that we can add tool args deltas
+                    args_json_delta = args_json[:-3]
+                    assert args_json_delta.endswith('"tool_args":'), (
+                        f'Expected {args_json_delta!r} to end in `"tool_args":`'
+                    )
+
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=event.index, part=replace(call_part, args=None)
+                    )
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=event.index,
+                        args=args_json_delta,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                elif isinstance(current_block, BetaMCPToolResultBlock):
+                    call_part = builtin_tool_calls.get(current_block.tool_use_id)
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=event.index,
+                        part=_map_mcp_server_result_block(current_block, call_part, self.provider_name),
+                    )
 
             elif isinstance(event, BetaRawContentBlockDeltaEvent):
                 if isinstance(event.delta, BetaTextDelta):
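The BetaMCPToolUseBlock branch above leans on a serialization invariant: a freshly mapped call part has args {'action': 'call_tool', 'tool_name': ..., 'tool_args': {}}, so its compact JSON ends in "tool_args":{}} and slicing off the last three characters leaves a prefix that the provider's raw input-JSON deltas can extend; the matching closing '}' is emitted on the content-block stop event further down. A worked illustration, using json.dumps with compact separators to stand in for args_as_json_str():

import json

args = {'action': 'call_tool', 'tool_name': 'ask_question', 'tool_args': {}}
args_json = json.dumps(args, separators=(',', ':'))
prefix = args_json[:-3]                          # drop the trailing '{}}'
assert prefix.endswith('"tool_args":')

deltas = ['{"repo":', '"pydantic/pydantic-ai"}']  # hypothetical input deltas
final = prefix + ''.join(deltas) + '}'            # '}' is added at block stop
assert json.loads(final)['tool_args'] == {'repo': 'pydantic/pydantic-ai'}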
@@ -749,7 +837,16 @@ class AnthropicStreamedResponse(StreamedResponse):
                     self.provider_details = {'finish_reason': raw_finish_reason}
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
-            elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent):  # pragma: no branch
+            elif isinstance(event, BetaRawContentBlockStopEvent):  # pragma: no branch
+                if isinstance(current_block, BetaMCPToolUseBlock):
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=event.index,
+                        args='}',
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                current_block = None
+            elif isinstance(event, BetaRawMessageStopEvent):  # pragma: no branch
                 current_block = None
 
     @property
@@ -817,3 +914,27 @@ def _map_code_execution_tool_result_block(
         content=code_execution_tool_result_content_ta.dump_python(item.content, mode='json'),
         tool_call_id=item.tool_use_id,
     )
+
+
+def _map_mcp_server_use_block(item: BetaMCPToolUseBlock, provider_name: str) -> BuiltinToolCallPart:
+    return BuiltinToolCallPart(
+        provider_name=provider_name,
+        tool_name=':'.join([MCPServerTool.kind, item.server_name]),
+        args={
+            'action': 'call_tool',
+            'tool_name': item.name,
+            'tool_args': cast(dict[str, Any], item.input),
+        },
+        tool_call_id=item.id,
+    )
+
+
+def _map_mcp_server_result_block(
+    item: BetaMCPToolResultBlock, call_part: BuiltinToolCallPart | None, provider_name: str
+) -> BuiltinToolReturnPart:
+    return BuiltinToolReturnPart(
+        provider_name=provider_name,
+        tool_name=call_part.tool_name if call_part else MCPServerTool.kind,
+        content=item.model_dump(mode='json', include={'content', 'is_error'}),
+        tool_call_id=item.tool_use_id,
+    )
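These two helpers fix the naming convention that the _map_message hunks above rely on: the MCP server name is embedded in the part's tool_name after a colon and recovered with split(':', 1). A round-trip sketch, assuming MCPServerTool.kind is the string 'mcp_server' (its value is not shown in this diff):

kind = 'mcp_server'  # assumed value of MCPServerTool.kind
server_name = 'deepwiki'

tool_name = ':'.join([kind, server_name])  # as in _map_mcp_server_use_block
assert tool_name == 'mcp_server:deepwiki'

server_id = tool_name.split(':', 1)[1]     # as recovered in _map_message
assert server_id == 'deepwiki'

The remaining hunks leave the Anthropic model; BedrockConverseModel picks up the same 'gateway' provider option next.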
@@ -207,7 +207,7 @@ class BedrockConverseModel(Model):
         self,
         model_name: BedrockModelName,
         *,
-        provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
+        provider: Literal['bedrock', 'gateway'] | Provider[BaseClient] = 'bedrock',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -226,7 +226,7 @@ class BedrockConverseModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)
 
@@ -701,8 +701,8 @@ class BedrockStreamedResponse(StreamedResponse):
                     signature=signature,
                     provider_name=self.provider_name if signature else None,
                 )
-            if 'text' in delta:
-                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
+            if text := delta.get('text'):
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
             if maybe_event is not None:  # pragma: no branch
                 yield maybe_event
             if 'toolUse' in delta:
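Replacing the key-presence check with a walrus-plus-truthiness check means empty text deltas are now dropped rather than forwarded as empty text parts; the `if content:` changes in the Groq and Hugging Face models below apply the same rule. A two-line illustration of the difference:

delta = {'text': ''}

assert 'text' in delta        # old guard: passes even for an empty delta
assert not delta.get('text')  # new guard: '' and a missing key both skip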
@@ -62,15 +62,8 @@ except ImportError as _import_error:
 LatestCohereModelNames = Literal[
     'c4ai-aya-expanse-32b',
     'c4ai-aya-expanse-8b',
-    'command',
-    'command-light',
-    'command-light-nightly',
     'command-nightly',
-    'command-r',
-    'command-r-03-2024',
     'command-r-08-2024',
-    'command-r-plus',
-    'command-r-plus-04-2024',
     'command-r-plus-08-2024',
     'command-r7b-12-2024',
 ]
@@ -38,7 +38,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
-from ..providers import Provider, infer_provider
+from ..providers import Provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
@@ -131,7 +131,14 @@ class GeminiModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            if provider == 'google-gla':
+                from pydantic_ai.providers.google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
+
+                provider = GoogleGLAProvider()  # type: ignore[reportDeprecated]
+            else:
+                from pydantic_ai.providers.google_vertex import GoogleVertexProvider  # type: ignore[reportDeprecated]
+
+                provider = GoogleVertexProvider()  # type: ignore[reportDeprecated]
         self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)
@@ -37,7 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
-from ..providers import Provider
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -85,8 +85,6 @@ try:
         UrlContextDict,
         VideoMetadataDict,
     )
-
-    from ..providers.google import GoogleProvider
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-genai` to use the Google model, '
@@ -128,6 +126,8 @@ _FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
     GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
     GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
     GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
+    GoogleFinishReason.IMAGE_PROHIBITED_CONTENT: 'content_filter',
+    GoogleFinishReason.NO_IMAGE: 'error',
 }
 
 
@@ -187,7 +187,7 @@ class GoogleModel(Model):
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex', 'gateway'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -196,15 +196,15 @@ class GoogleModel(Model):
         Args:
             model_name: The name of the model to use.
             provider: The provider to use for authentication and API access. Can be either the string
-                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
-                If not provided, a new provider will be created using the other parameters.
+                'google-gla' or 'google-vertex' or an instance of `Provider[google.genai.AsyncClient]`.
+                Defaults to 'google-gla'.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             settings: The model settings to use. Defaults to None.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = GoogleProvider(vertexai=provider == 'google-vertex')
+            provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
@@ -455,23 +455,28 @@ class GoogleModel(Model):
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates:
             raise UnexpectedModelBehavior('Expected at least one candidate in Gemini response')  # pragma: no cover
+
         candidate = response.candidates[0]
-        if candidate.content is None or candidate.content.parts is None:
-            if candidate.finish_reason == 'SAFETY':
-                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
-            else:
-                raise UnexpectedModelBehavior(
-                    'Content field missing from Gemini response', str(response)
-                )  # pragma: no cover
-        parts = candidate.content.parts or []
 
         vendor_id = response.response_id
         vendor_details: dict[str, Any] | None = None
         finish_reason: FinishReason | None = None
-        if raw_finish_reason := candidate.finish_reason:  # pragma: no branch
+        raw_finish_reason = candidate.finish_reason
+        if raw_finish_reason:  # pragma: no branch
            vendor_details = {'finish_reason': raw_finish_reason.value}
            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
+        if candidate.content is None or candidate.content.parts is None:
+            if finish_reason == 'content_filter' and raw_finish_reason:
+                raise UnexpectedModelBehavior(
+                    f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
+                )
+            else:
+                raise UnexpectedModelBehavior(
+                    'Content field missing from Gemini response', response.model_dump_json()
+                )  # pragma: no cover
+        parts = candidate.content.parts or []
+
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
             parts,
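_process_response now resolves the finish reason before the empty-content check, so a filtered response (including the newly mapped IMAGE_PROHIBITED_CONTENT) raises an error naming the raw reason and attaches the response as JSON instead of str(response). A sketch of what calling code might observe; the model string and prompt are illustrative, and UnexpectedModelBehavior is assumed to expose its usual message/body attributes:

from pydantic_ai import Agent, UnexpectedModelBehavior

agent = Agent('google-gla:gemini-2.5-flash')
try:
    agent.run_sync('...')
except UnexpectedModelBehavior as e:
    # e.g. "Content filter 'IMAGE_PROHIBITED_CONTENT' triggered";
    # e.body now carries response.model_dump_json() for debugging
    print(e.message, e.body)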
@@ -625,7 +630,8 @@ class GeminiStreamedResponse(StreamedResponse):
             if chunk.response_id:  # pragma: no branch
                 self.provider_response_id = chunk.response_id
 
-            if raw_finish_reason := candidate.finish_reason:
+            raw_finish_reason = candidate.finish_reason
+            if raw_finish_reason:
                 self.provider_details = {'finish_reason': raw_finish_reason.value}
                 self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
@@ -643,13 +649,17 @@ class GeminiStreamedResponse(StreamedResponse):
             # )
 
             if candidate.content is None or candidate.content.parts is None:
-                if candidate.finish_reason == 'STOP':  # pragma: no cover
+                if self.finish_reason == 'stop':  # pragma: no cover
                     # Normal completion - skip this chunk
                     continue
-                elif candidate.finish_reason == 'SAFETY':  # pragma: no cover
-                    raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
+                elif self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
+                    raise UnexpectedModelBehavior(
+                        f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
+                    )
                 else:  # pragma: no cover
-                    raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+                    raise UnexpectedModelBehavior(
+                        'Content field missing from streaming Gemini response', chunk.model_dump_json()
+                    )
 
             parts = candidate.content.parts
             if not parts:
@@ -141,7 +141,7 @@ class GroqModel(Model):
         self,
         model_name: GroqModelName,
         *,
-        provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
+        provider: Literal['groq', 'gateway'] | Provider[AsyncGroq] = 'groq',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -159,7 +159,7 @@ class GroqModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/groq' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
@@ -330,7 +330,7 @@ class GroqModel(Model):
            if call_part and return_part:  # pragma: no branch
                items.append(call_part)
                items.append(return_part)
-        if choice.message.content is not None:
+        if choice.message.content:
             # NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
             items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
         if choice.message.tool_calls is not None:
@@ -563,7 +563,7 @@ class GroqStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
             content = choice.delta.content
-            if content is not None:
+            if content:
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id='content',
                     content=content,
@@ -277,7 +277,7 @@ class HuggingFaceModel(Model):
 
         items: list[ModelResponsePart] = []
 
-        if content is not None:
+        if content:
             items.extend(split_content_into_text_and_thinking(content, self.profile.thinking_tags))
         if tool_calls is not None:
             for c in tool_calls:
@@ -482,7 +482,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
             content = choice.delta.content
-            if content is not None:
+            if content:
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id='content',
                     content=content,