pydantic-ai-slim 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -0,0 +1,569 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import base64
4
+ import warnings
5
+ from collections.abc import AsyncIterator, Awaitable
6
+ from contextlib import asynccontextmanager
7
+ from dataclasses import dataclass, field, replace
8
+ from datetime import datetime
9
+ from typing import Literal, Union, cast, overload
10
+ from uuid import uuid4
11
+
12
+ from typing_extensions import assert_never
13
+
14
+ from pydantic_ai.providers import Provider
15
+
16
+ from .. import UnexpectedModelBehavior, UserError, _utils, usage
17
+ from ..messages import (
18
+ AudioUrl,
19
+ BinaryContent,
20
+ DocumentUrl,
21
+ ImageUrl,
22
+ ModelMessage,
23
+ ModelRequest,
24
+ ModelResponse,
25
+ ModelResponsePart,
26
+ ModelResponseStreamEvent,
27
+ RetryPromptPart,
28
+ SystemPromptPart,
29
+ TextPart,
30
+ ToolCallPart,
31
+ ToolReturnPart,
32
+ UserPromptPart,
33
+ VideoUrl,
34
+ )
35
+ from ..settings import ModelSettings
36
+ from ..tools import ToolDefinition
37
+ from . import (
38
+ Model,
39
+ ModelRequestParameters,
40
+ StreamedResponse,
41
+ cached_async_http_client,
42
+ check_allow_model_requests,
43
+ get_user_agent,
44
+ )
45
+ from ._json_schema import JsonSchema, WalkJsonSchema
46
+
47
+ try:
48
+ from google import genai
49
+ from google.genai.types import (
50
+ ContentDict,
51
+ ContentUnionDict,
52
+ FunctionCallDict,
53
+ FunctionCallingConfigDict,
54
+ FunctionCallingConfigMode,
55
+ FunctionDeclarationDict,
56
+ GenerateContentConfigDict,
57
+ GenerateContentResponse,
58
+ Part,
59
+ PartDict,
60
+ SafetySettingDict,
61
+ ThinkingConfigDict,
62
+ ToolConfigDict,
63
+ ToolDict,
64
+ ToolListUnionDict,
65
+ )
66
+
67
+ from ..providers.google import GoogleProvider
68
+ except ImportError as _import_error:
69
+ raise ImportError(
70
+ 'Please install `google-genai` to use the Google model, '
71
+ 'you can use the `google` optional group — `pip install "pydantic-ai-slim[google]"`'
72
+ ) from _import_error
73
+
74
LatestGoogleModelNames = Literal[
    'gemini-1.5-flash',
    'gemini-1.5-flash-8b',
    'gemini-1.5-pro',
    'gemini-1.0-pro',
    'gemini-2.0-flash-exp',
    'gemini-2.0-flash-thinking-exp-01-21',
    'gemini-exp-1206',
    'gemini-2.0-flash',
    'gemini-2.0-flash-lite-preview-02-05',
    'gemini-2.0-pro-exp-02-05',
    'gemini-2.5-flash-preview-04-17',
    'gemini-2.5-pro-exp-03-25',
    'gemini-2.5-pro-preview-03-25',
]
"""Explicitly enumerated, most recent Gemini model names."""

GoogleModelName = Union[str, LatestGoogleModelNames]
"""Any Gemini model name.

Gemini publishes many date-stamped model variants, so only the latest names are listed explicitly
(for editor completion) while the type hint still accepts an arbitrary string.
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""
98
+
99
+
100
class GoogleModelSettings(ModelSettings, total=False):
    """Settings used for a Gemini model request.

    ALL FIELDS MUST BE `google_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    google_safety_settings: list[SafetySettingDict]
    """The safety settings to use for the model.

    See <https://ai.google.dev/gemini-api/docs/safety-settings> for more information.
    """

    google_thinking_config: ThinkingConfigDict
    """The thinking configuration to use for the model.

    See <https://ai.google.dev/gemini-api/docs/thinking> for more information.
    """
117
+
118
+
119
@dataclass(init=False)
class GoogleModel(Model):
    """A model that uses Gemini through the `google-genai` SDK.

    Every request is sent with the `genai.Client` exposed by the configured provider. When the
    provider is given as a string, a `GoogleProvider` is created and Vertex AI is enabled only
    for `'google-vertex'`.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: genai.Client = field(repr=False)

    _model_name: GoogleModelName = field(repr=False)
    _provider: Provider[genai.Client] = field(repr=False)
    _url: str | None = field(repr=False)
    _system: str = field(default='google', repr=False)

    def __init__(
        self,
        model_name: GoogleModelName,
        *,
        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            provider: The provider to use for authentication and API access. Can be either the string
                'google-gla' or 'google-vertex' or an instance of `Provider[genai.Client]`.
                A string is turned into a `GoogleProvider`; defaults to 'google-gla'.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = GoogleProvider(vertexai=provider == 'google-vertex')  # pragma: lax no cover

        self._provider = provider
        self._system = provider.name
        self.client = provider.client
        # `_url` is a declared dataclass field but nothing else assigns it. Give it a value so the
        # dataclass-generated `__eq__`, which reads every field, cannot raise `AttributeError`.
        self._url = None

    @property
    def base_url(self) -> str:
        """The base URL reported by the provider."""
        return self._provider.base_url

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a non-streamed request and convert the result into a `ModelResponse`."""
        check_allow_model_requests()
        # `None` settings are treated as an empty settings dict.
        model_settings = cast(GoogleModelSettings, model_settings or {})
        response = await self._generate_content(messages, False, model_settings, model_request_parameters)
        return self._process_response(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streamed request and yield a `StreamedResponse` wrapping the chunk stream."""
        check_allow_model_requests()
        model_settings = cast(GoogleModelSettings, model_settings or {})
        response = await self._generate_content(messages, True, model_settings, model_request_parameters)
        yield await self._process_streamed_response(response)  # type: ignore

    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
        """Return request parameters whose tool schemas were rewritten by `_GeminiJsonSchema`."""

        def _customize_tool_def(t: ToolDefinition):
            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())

        return ModelRequestParameters(
            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
            allow_text_output=model_request_parameters.allow_text_output,
            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
        )

    @property
    def model_name(self) -> GoogleModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
        """Build one `ToolDict` per function tool and output tool, or `None` when there are none."""
        tools: list[ToolDict] = [
            ToolDict(function_declarations=[_function_declaration_from_tool(t)])
            for t in model_request_parameters.function_tools
        ]
        if model_request_parameters.output_tools:
            tools += [
                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
                for t in model_request_parameters.output_tools
            ]
        return tools or None

    def _get_tool_config(
        self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None
    ) -> ToolConfigDict | None:
        """Return the tool config for the request.

        No config is sent when plain text output is allowed. Otherwise the config built by
        `_tool_config` lists the names of all declared functions.
        """
        if model_request_parameters.allow_text_output:
            return None
        elif tools:
            names: list[str] = []
            for tool in tools:
                for function_declaration in tool.get('function_declarations') or []:
                    if name := function_declaration.get('name'):  # pragma: no branch
                        names.append(name)
            return _tool_config(names)
        else:
            return _tool_config([])  # pragma: no cover

    @overload
    async def _generate_content(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: GoogleModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> GenerateContentResponse: ...

    @overload
    async def _generate_content(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: GoogleModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> Awaitable[AsyncIterator[GenerateContentResponse]]: ...

    async def _generate_content(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: GoogleModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
        """Assemble the request config and call the SDK.

        Args:
            messages: The conversation to send.
            stream: When true, `generate_content_stream` is called instead of `generate_content`.
            model_settings: Settings copied into the request config; missing keys become `None`.
            model_request_parameters: Source of the tool declarations and tool config.

        Returns:
            Whatever the awaited SDK call returns.
        """
        tools = self._get_tools(model_request_parameters)
        tool_config = self._get_tool_config(model_request_parameters, tools)
        system_instruction, contents = await self._map_messages(messages)

        config = GenerateContentConfigDict(
            http_options={'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}},
            system_instruction=system_instruction,
            temperature=model_settings.get('temperature'),
            top_p=model_settings.get('top_p'),
            max_output_tokens=model_settings.get('max_tokens'),
            presence_penalty=model_settings.get('presence_penalty'),
            frequency_penalty=model_settings.get('frequency_penalty'),
            safety_settings=model_settings.get('google_safety_settings'),
            thinking_config=model_settings.get('google_thinking_config'),
            tools=cast(ToolListUnionDict, tools),
            tool_config=tool_config,
        )

        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore

    def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
        """Convert a non-streamed SDK response into a `ModelResponse`.

        Raises:
            UnexpectedModelBehavior: If the response does not hold exactly one candidate, or the
                candidate has no content/parts (reported as a safety trigger when the finish
                reason is `'SAFETY'`).
        """
        if not response.candidates or len(response.candidates) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
        if response.candidates[0].content is None or response.candidates[0].content.parts is None:
            if response.candidates[0].finish_reason == 'SAFETY':
                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
            else:
                raise UnexpectedModelBehavior(
                    'Content field missing from Gemini response', str(response)
                )  # pragma: no cover
        parts = response.candidates[0].content.parts or []
        usage = _metadata_as_usage(response)
        usage.requests = 1
        # Fall back to the requested model name when the response does not report a version.
        return _process_response_from_parts(parts, response.model_version or self._model_name, usage)

    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        # Peek at the first chunk to detect an empty stream and to read the creation time.
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover

        return GeminiStreamedResponse(
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=first_chunk.create_time or _utils.now_utc(),
        )

    async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
        """Translate the message history into a system instruction and a list of contents.

        System prompt parts are gathered into the system instruction, with the value of
        `_get_instructions` (when truthy) placed first. All other request parts of one message
        become a single `'user'` content; model responses are mapped by `_content_model_response`.

        Returns:
            A `(system_instruction, contents)` tuple; `system_instruction` is `None` when there
            are no system parts.
        """
        contents: list[ContentUnionDict] = []
        system_parts: list[PartDict] = []

        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[PartDict] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_parts.append({'text': part.content})
                    elif isinstance(part, UserPromptPart):
                        message_parts.extend(await self._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(
                            {
                                'function_response': {
                                    'name': part.tool_name,
                                    'response': part.model_response_object(),
                                    'id': part.tool_call_id,
                                }
                            }
                        )
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            # A retry that is not tied to a tool is sent back as plain text.
                            message_parts.append({'text': part.model_response()})  # pragma: no cover
                        else:
                            message_parts.append(
                                {
                                    'function_response': {
                                        'name': part.tool_name,
                                        'response': {'call_error': part.model_response()},
                                        'id': part.tool_call_id,
                                    }
                                }
                            )
                    else:
                        assert_never(part)

                if message_parts:  # pragma: no branch
                    contents.append({'role': 'user', 'parts': message_parts})
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)
        if instructions := self._get_instructions(messages):
            system_parts.insert(0, {'text': instructions})
        system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None
        return system_instruction, contents

    async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
        """Convert a user prompt into request parts.

        Strings become text parts. Binary content is base64-encoded inline. URL items are
        downloaded with the cached HTTP client (following redirects) and then inlined the same
        way; a failed download raises through `raise_for_status()`.
        """
        if isinstance(part.content, str):
            return [{'text': part.content}]
        else:
            content: list[PartDict] = []
            for item in part.content:
                if isinstance(item, str):
                    content.append({'text': item})
                elif isinstance(item, BinaryContent):
                    # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl, VideoUrl)):
                    client = cached_async_http_client()
                    response = await client.get(item.url, follow_redirects=True)
                    response.raise_for_status()
                    # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
                    content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
                else:
                    assert_never(item)
            return content
379
+
380
+
381
@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """`StreamedResponse` implementation that consumes Gemini stream chunks."""

    _model_name: GoogleModelName
    _response: AsyncIterator[GenerateContentResponse]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield stream events for the text and function-call parts of every chunk.

        Only the first candidate of each chunk is read.

        Raises:
            UnexpectedModelBehavior: If that candidate has no content.
        """
        async for chunk in self._response:
            # NOTE(review): usage is summed over all chunks — confirm the API reports per-chunk
            # counts rather than running totals.
            self._usage += _metadata_as_usage(chunk)

            assert chunk.candidates is not None
            content = chunk.candidates[0].content
            if content is None:
                raise UnexpectedModelBehavior('Streamed response has no content field')  # pragma: no cover
            assert content.parts is not None

            for part in content.parts:
                if part.text:
                    # All text is routed to the single 'content' vendor part.
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
                    continue

                if part.function_call:
                    # A fresh UUID is used as the vendor part id for each function call.
                    tool_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=part.function_call.name,
                        args=part.function_call.args,
                        tool_call_id=part.function_call.id,
                    )
                    if tool_event is not None:  # pragma: no branch
                        yield tool_event
                    continue

                assert part.function_response is not None, f'Unexpected part: {part}'  # pragma: no cover

    @property
    def model_name(self) -> GoogleModelName:
        """The name of the model that produced this response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """The timestamp recorded for this response."""
        return self._timestamp
422
+
423
+
424
def _content_model_response(m: ModelResponse) -> ContentDict:
    """Map a previous model response onto a `'model'` role content dict.

    Tool calls become `function_call` parts; text parts are included only when non-empty.
    """
    mapped: list[PartDict] = []
    for response_part in m.parts:
        if isinstance(response_part, TextPart):
            if response_part.content:  # pragma: no branch
                mapped.append({'text': response_part.content})
        elif isinstance(response_part, ToolCallPart):
            mapped.append(
                {
                    'function_call': FunctionCallDict(
                        name=response_part.tool_name,
                        args=response_part.args_as_dict(),
                        id=response_part.tool_call_id,
                    )
                }
            )
        else:
            assert_never(response_part)
    return ContentDict(role='model', parts=mapped)
436
+
437
+
438
def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
    """Build a `ModelResponse` from the parts of a Gemini candidate.

    Parts with truthy text become `TextPart`s and function calls become `ToolCallPart`s (keeping
    the call id when one is present). A part carrying a function response is rejected; a part
    with none of the three is skipped.

    Raises:
        UnexpectedModelBehavior: If a part carries a function response.
    """
    collected: list[ModelResponsePart] = []
    for part in parts:
        if part.text:
            collected.append(TextPart(content=part.text))
            continue

        call = part.function_call
        if call:
            assert call.name is not None
            tool_call = ToolCallPart(tool_name=call.name, args=call.args or {})
            if call.id is not None:
                tool_call.tool_call_id = call.id  # pragma: no cover
            collected.append(tool_call)
            continue

        if part.function_response:  # pragma: no cover
            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=collected, model_name=model_name, usage=usage)
454
+
455
+
456
def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
    """Create the function declaration for a tool.

    The `parameters` entry is only set when the tool's JSON schema has a non-empty `properties`.
    """
    declaration = FunctionDeclarationDict(name=tool.name, description=tool.description)
    parameters_schema = tool.parameters_json_schema
    if parameters_schema.get('properties'):  # pragma: no branch
        declaration['parameters'] = parameters_schema  # type: ignore
    return declaration
462
+
463
+
464
def _tool_config(function_names: list[str]) -> ToolConfigDict:
    """Return a tool config in `ANY` function-calling mode limited to `function_names`."""
    return ToolConfigDict(
        function_calling_config=FunctionCallingConfigDict(
            mode=FunctionCallingConfigMode.ANY,
            allowed_function_names=function_names,
        )
    )
468
+
469
+
470
def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
    """Convert the usage metadata of a response (or stream chunk) into a `usage.Usage`.

    Prompt, candidate and total token counts fill the main fields. Every remaining integer
    count is passed through in `details`.

    Returns:
        An empty `usage.Usage` when the response carries no usage metadata.
    """
    metadata = response.usage_metadata
    if metadata is None:
        return usage.Usage()  # pragma: no cover
    # TODO(Marcelo): We exclude the `prompt_tokens_details` and `candidates_tokens_details` fields because on
    # `usage.Usage.incr`, it will try to sum non-integer values with integers, which will fail. We should probably
    # handle this in the `Usage` class.
    dumped = metadata.model_dump(
        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
        exclude_defaults=True,
    )
    # Keep only integer counts. Any other non-integer entry (for example another `*_details`
    # list) would hit the same summing failure described above.
    details = {key: value for key, value in dumped.items() if isinstance(value, int)}
    return usage.Usage(
        request_tokens=details.pop('prompt_token_count', 0),
        response_tokens=details.pop('candidates_token_count', 0),
        total_tokens=details.pop('total_token_count', 0),
        details=details,
    )
487
+
488
+
489
class _GeminiJsonSchema(WalkJsonSchema):
    """Rewrites a Pydantic-generated JSON Schema into a form Gemini accepts.

    Gemini [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
    a subset of OpenAPI v3.0.3.

    Specifically:
    * gemini doesn't allow the `title` keyword to be set
    * gemini doesn't allow `$defs` — we need to inline the definitions where possible
    """

    def __init__(self, schema: JsonSchema):
        super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Strip or rewrite the keywords Gemini rejects; mutates and returns `schema`.

        Raises:
            UserError: If a `$ref` is still present in the schema.
        """
        # `additionalProperties` is removed because it is currently mishandled by Gemini.
        dropped_additional = schema.pop('additionalProperties', None)
        if dropped_additional:  # pragma: no cover
            # The keyword is already popped, so rebuild the pre-pop schema for the warning text.
            original_schema = {**schema, 'additionalProperties': dropped_additional}
            warnings.warn(
                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
                f' Full schema: {self.schema}\n\n'
                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
                ' and we will fix this behavior.',
                UserWarning,
            )

        for unsupported_key in ('title', 'default', '$schema'):
            schema.pop(unsupported_key, None)

        const_value = schema.pop('const', None)
        if const_value is not None:  # pragma: no cover
            # Gemini doesn't support const, but it does support enum with a single value
            schema['enum'] = [const_value]

        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
        # where we add notes about the exclusive bounds to the field description?
        for unsupported_key in ('discriminator', 'examples', 'exclusiveMaximum', 'exclusiveMinimum'):
            schema.pop(unsupported_key, None)

        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
            # Hit for discriminated unions: Gemini returns an API error for `oneOf` here, while
            # `anyOf` avoids the error and is believed to be functionally equivalent.
            schema['anyOf'] = schema.pop('oneOf')

        if schema.get('type') == 'string':
            # `format` is folded into the description of string schemas.
            fmt = schema.pop('format', None)
            if fmt:
                existing_description = schema.get('description')
                if existing_description:
                    schema['description'] = f'{existing_description} (format: {fmt})'
                else:
                    schema['description'] = f'Format: {fmt}'

        if '$ref' in schema:
            raise UserError(  # pragma: no cover
                f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}'
            )

        if 'prefixItems' in schema:  # pragma: lax no cover
            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
            prefix_items = schema.pop('prefixItems')
            base_items = schema.get('items')
            # Distinct item schemas, starting with any existing `items` value.
            candidates = [] if base_items is None else [base_items]
            for entry in prefix_items:
                if entry not in candidates:
                    candidates.append(entry)
            if len(candidates) > 1:  # pragma: no cover
                schema['items'] = {'anyOf': candidates}
            elif candidates:
                schema['items'] = candidates[0]
            schema.setdefault('minItems', len(prefix_items))
            if base_items is None:
                # Without a trailing `items` schema the length is capped at the prefix length.
                schema.setdefault('maxItems', len(prefix_items))

        return schema
@@ -227,7 +227,7 @@ class GroqModel(Model):
227
227
  except APIStatusError as e:
228
228
  if (status_code := e.status_code) >= 400:
229
229
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
230
- raise
230
+ raise # pragma: lax no cover
231
231
 
232
232
  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
233
233
  """Process a non-streamed response, and prepare a message to return."""
@@ -239,14 +239,18 @@ class GroqModel(Model):
239
239
  if choice.message.tool_calls is not None:
240
240
  for c in choice.message.tool_calls:
241
241
  items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
242
- return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
242
+ return ModelResponse(
243
+ items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
244
+ )
243
245
 
244
246
  async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
245
247
  """Process a streamed response, and prepare a streaming response to return."""
246
248
  peekable_response = _utils.PeekableAsyncStream(response)
247
249
  first_chunk = await peekable_response.peek()
248
250
  if isinstance(first_chunk, _utils.Unset):
249
- raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
251
+ raise UnexpectedModelBehavior( # pragma: no cover
252
+ 'Streamed response ended without content or tool calls'
253
+ )
250
254
 
251
255
  return GroqStreamedResponse(
252
256
  _response=peekable_response,
@@ -322,9 +326,11 @@ class GroqModel(Model):
322
326
  tool_call_id=_guard_tool_call_id(t=part),
323
327
  content=part.model_response_str(),
324
328
  )
325
- elif isinstance(part, RetryPromptPart):
329
+ elif isinstance(part, RetryPromptPart): # pragma: no branch
326
330
  if part.tool_name is None:
327
- yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
331
+ yield chat.ChatCompletionUserMessageParam( # pragma: no cover
332
+ role='user', content=part.model_response()
333
+ )
328
334
  else:
329
335
  yield chat.ChatCompletionToolMessageParam(
330
336
  role='tool',
@@ -409,7 +415,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
409
415
  if isinstance(completion, chat.ChatCompletion):
410
416
  response_usage = completion.usage
411
417
  elif completion.x_groq is not None:
412
- response_usage = completion.x_groq.usage
418
+ response_usage = completion.x_groq.usage # pragma: no cover
413
419
 
414
420
  if response_usage is None:
415
421
  return usage.Usage()