pydantic-ai-slim 0.0.6a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -0,0 +1,720 @@
1
+ """Custom interface to the `generativelanguage.googleapis.com` API using [HTTPX] and [Pydantic].
2
+
3
+ The Google SDK for interacting with the `generativelanguage.googleapis.com` API
4
+ [`google-generativeai`](https://ai.google.dev/gemini-api/docs/quickstart?lang=python) reads like it was written by a
5
+ Java developer who thought they knew everything about OOP, spent 30 minutes trying to learn Python,
6
+ gave up and decided to build the library to prove how horrible Python is. It also doesn't use httpx for HTTP requests,
7
+ and tries to implement tool calling itself, but doesn't use Pydantic or equivalent for validation.
8
+
9
+ We could also use the Google Vertex SDK,
10
+ [`google-cloud-aiplatform`](https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk)
11
+ which uses the `*-aiplatform.googleapis.com` API, but that requires a service account for authentication
12
+ which is a faff to set up and manage.
13
+
14
+ Both APIs claim compatibility with OpenAI's API, but that breaks down with even the simplest of requests,
15
+ hence this custom interface.
16
+
17
+ Despite these limitations, the Gemini model is actually quite powerful and very fast.
18
+
19
+ [HTTPX]: https://www.python-httpx.org/
20
+ [Pydantic]: https://docs.pydantic.dev/latest/
21
+ """
22
+
23
+ from __future__ import annotations as _annotations
24
+
25
+ import os
26
+ import re
27
+ from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
28
+ from contextlib import asynccontextmanager
29
+ from copy import deepcopy
30
+ from dataclasses import dataclass, field
31
+ from datetime import datetime
32
+ from typing import Annotated, Any, Literal, Protocol, Union
33
+
34
+ import pydantic_core
35
+ from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
36
+ from pydantic import Discriminator, Field, Tag
37
+ from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
38
+
39
+ from .. import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result
40
+ from ..messages import (
41
+ ArgsObject,
42
+ Message,
43
+ ModelAnyResponse,
44
+ ModelStructuredResponse,
45
+ ModelTextResponse,
46
+ RetryPrompt,
47
+ ToolCall,
48
+ ToolReturn,
49
+ )
50
+ from . import (
51
+ AbstractToolDefinition,
52
+ AgentModel,
53
+ EitherStreamedResponse,
54
+ Model,
55
+ StreamStructuredResponse,
56
+ StreamTextResponse,
57
+ cached_async_http_client,
58
+ check_allow_model_requests,
59
+ get_user_agent,
60
+ )
61
+
62
# Closed set of model identifiers accepted by `GeminiModel`; the chosen value is
# substituted for `{model}` in the request URL template.
GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro']
"""Named Gemini models.

See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""
67
+
68
+
69
@dataclass(init=False)
class GeminiModel(Model):
    """Gemini model accessed through the `generativelanguage.googleapis.com` REST API.

    No vendor SDK is involved; requests are built and sent by this module directly. API documentation is
    available [here](https://ai.google.dev/api).

    Only `__init__` is specific to this class, every other method is private or mirrors the base class.
    """

    model_name: GeminiModelName
    auth: AuthProtocol
    http_client: AsyncHTTPClient
    url: str

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        api_key: str | None = None,
        http_client: AsyncHTTPClient | None = None,
        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key used for authentication. When omitted, the `GEMINI_API_KEY` environment
                variable is used instead.
            http_client: An existing `httpx.AsyncClient` to send requests with; a cached shared client is
                used when omitted.
            url_template: Template for the request URL, `{model}` is replaced by the model name and the
                API function name is appended when a request is made. You shouldn't need to change this, see
                the docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request).

        Raises:
            exceptions.UserError: If no API key is passed and `GEMINI_API_KEY` is unset or empty.
        """
        self.model_name = model_name

        if api_key is None:
            # fall back to the environment; an empty value counts as "not set"
            api_key = os.getenv('GEMINI_API_KEY')
            if not api_key:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')

        self.auth = ApiKeyAuth(api_key)
        if not http_client:
            http_client = cached_async_http_client()
        self.http_client = http_client
        self.url = url_template.format(model=model_name)

    async def agent_model(
        self,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ) -> GeminiAgentModel:
        """Create the per-agent model, sharing this model's client, auth and URL."""
        agent_model = GeminiAgentModel(
            http_client=self.http_client,
            model_name=self.model_name,
            auth=self.auth,
            url=self.url,
            retrievers=retrievers,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )
        return agent_model

    def name(self) -> str:
        """Return the configured model name."""
        return self.model_name
132
+
133
+ class AuthProtocol(Protocol):
134
+ async def headers(self) -> dict[str, str]: ...
135
+
136
+
137
+ @dataclass
138
+ class ApiKeyAuth:
139
+ api_key: str
140
+
141
+ async def headers(self) -> dict[str, str]:
142
+ # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
143
+ return {'X-Goog-Api-Key': self.api_key}
144
+
145
+
146
@dataclass(init=False)
class GeminiAgentModel(AgentModel):
    """`AgentModel` implementation that sends requests to Gemini over HTTP."""

    http_client: AsyncHTTPClient
    model_name: GeminiModelName
    auth: AuthProtocol
    tools: _GeminiTools | None
    tool_config: _GeminiToolConfig | None
    url: str

    def __init__(
        self,
        http_client: AsyncHTTPClient,
        model_name: GeminiModelName,
        auth: AuthProtocol,
        url: str,
        retrievers: Mapping[str, AbstractToolDefinition],
        allow_text_result: bool,
        result_tools: Sequence[AbstractToolDefinition] | None,
    ):
        """Prepare the tool declarations and tool config sent with every request.

        Retrievers and result tools are both declared to Gemini as functions. When plain text results are
        not allowed, a tool config restricting the model to those functions is created.
        """
        check_allow_model_requests()

        functions = [_function_from_abstract_tool(tool) for tool in retrievers.values()]
        if result_tools is not None:
            functions.extend(_function_from_abstract_tool(tool) for tool in result_tools)

        self.http_client = http_client
        self.model_name = model_name
        self.auth = auth
        self.tools = _GeminiTools(function_declarations=functions) if functions else None
        self.tool_config = None if allow_text_result else _tool_config([f['name'] for f in functions])
        self.url = url

    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
        """Make a non-streamed request and return the parsed response together with its cost."""
        async with self._make_request(messages, False) as http_response:
            body = await http_response.aread()
            response = _gemini_response_ta.validate_json(body)
            return self._process_response(response), _metadata_as_cost(response)

    @asynccontextmanager
    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
        """Make a streamed request; the yielded object is valid while the context is open."""
        async with self._make_request(messages, True) as http_response:
            streamed = await self._process_streamed_response(http_response)
            yield streamed

    @asynccontextmanager
    async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]:
        """Build the request body from `messages`, POST it and yield the open HTTP response.

        Raises:
            exceptions.UnexpectedModelBehavior: If the response status is not 200.
        """
        contents: list[_GeminiContent] = []
        sys_prompt_parts: list[_GeminiTextPart] = []
        for message in messages:
            converted = self._message_to_gemini(message)
            # system prompts are collected separately, everything else goes into `contents`
            if left := converted.left:
                sys_prompt_parts.append(left.value)
            else:
                contents.append(converted.right)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if self.tools is not None:
            request_data['tools'] = self.tools
        if self.tool_config is not None:
            request_data['tool_config'] = self.tool_config

        function = 'streamGenerateContent' if streamed else 'generateContent'
        url = f'{self.url}{function}'

        auth_headers = await self.auth.headers()
        headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
        headers.update(auth_headers)

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

        async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
            if r.status_code != 200:
                # the body has to be read before `r.text` can be used on a streamed response
                await r.aread()
                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
            yield r

    @staticmethod
    def _process_response(response: _GeminiResponse) -> ModelAnyResponse:
        """Convert a complete response into a structured (tool call) or text response."""
        either = _extract_response_parts(response)
        if left := either.left:
            return _structured_response_from_parts(left.value)

        text = ''.join([part['text'] for part in either.right])
        return ModelTextResponse(content=text)

    @staticmethod
    async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
        """Read the stream until the first response with content, then return the matching stream wrapper.

        The bytes consumed so far and the remaining byte iterator are handed over to the returned object.
        """
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                content,
                experimental_allow_partial='trailing-strings',
            )
            if not responses:
                continue
            last = responses[-1]
            if last['candidates'] and last['candidates'][0]['content']['parts']:
                start_response = last
                break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        if _extract_response_parts(start_response).is_left():
            return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
        return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)

    @staticmethod
    def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
        """Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents"."""
        if m.role == 'system':
            # SystemPrompt ->
            return _utils.Either(left=_GeminiTextPart(text=m.content))
        if m.role == 'user':
            # UserPrompt ->
            return _utils.Either(right=_content_user_text(m.content))
        if m.role == 'tool-return':
            # ToolReturn ->
            return _utils.Either(right=_content_function_return(m))
        if m.role == 'retry-prompt':
            # RetryPrompt ->
            return _utils.Either(right=_content_function_retry(m))
        if m.role == 'model-text-response':
            # ModelTextResponse ->
            return _utils.Either(right=_content_model_text(m.content))
        if m.role == 'model-structured-response':
            # ModelStructuredResponse ->
            return _utils.Either(right=_content_function_call(m))
        assert_never(m)
287
+
288
+
289
@dataclass
class GeminiStreamTextResponse(StreamTextResponse):
    """Implementation of `StreamTextResponse` for the Gemini model.

    The raw bytes of the streamed JSON array are accumulated in `_json_content`. `get` re-parses the
    buffer and yields the text of the array items that were not yielded by an earlier call.
    """

    # raw JSON received so far, a possibly incomplete JSON array of responses
    _json_content: bytearray
    # iterator over the remaining chunks of the HTTP response body
    _stream: AsyncIterator[bytes]
    # index of the first array item that has not been yielded yet
    _position: int = 0
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
    _cost: result.Cost = field(default_factory=result.Cost, init=False)

    async def __anext__(self) -> None:
        """Append the next chunk of the stream to the buffer.

        Whatever the underlying iterator raises (including at the end of the stream) propagates.
        """
        chunk = await self._stream.__anext__()
        self._json_content.extend(chunk)

    def get(self, *, final: bool = False) -> Iterable[str]:
        """Yield the text parts of responses that have arrived since the previous call.

        This is a generator, so nothing is parsed and `_position`/`_cost` are not updated until it is iterated.

        Args:
            final: If `True` the buffer must be complete JSON and every remaining item is yielded.
                Otherwise the buffer is parsed as partial JSON and the last item is held back.

        Raises:
            UnexpectedModelBehavior: If a new response contains a part that is not text.
        """
        if final:
            all_items = pydantic_core.from_json(self._json_content)
            new_items = all_items[self._position :]
            self._position = len(all_items)
            new_responses = _gemini_streamed_response_ta.validate_python(new_items)
        else:
            all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
            # the last item may still be incomplete, so it is held back until a later call
            new_items = all_items[self._position : -1]
            # clamp at 0: with no items `len(all_items) - 1` would be -1, and a negative position makes
            # later slices start from the end of the list, silently skipping items
            self._position = max(len(all_items) - 1, 0)
            new_responses = _gemini_streamed_response_ta.validate_python(
                new_items, experimental_allow_partial='trailing-strings'
            )
        for r in new_responses:
            self._cost += _metadata_as_cost(r)
            parts = r['candidates'][0]['content']['parts']
            if _all_text_parts(parts):
                for part in parts:
                    yield part['text']
            else:
                raise UnexpectedModelBehavior(
                    'Streamed response with unexpected content, expected all parts to be text'
                )

    def cost(self) -> result.Cost:
        """Return the cost accumulated from the responses yielded so far."""
        return self._cost

    def timestamp(self) -> datetime:
        """Return the timestamp taken when this object was created."""
        return self._timestamp
332
+
333
+
334
@dataclass
class GeminiStreamStructuredResponse(StreamStructuredResponse):
    """Implementation of `StreamStructuredResponse` for the Gemini model."""

    _content: bytearray
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
    _cost: result.Cost = field(default_factory=result.Cost, init=False)

    async def __anext__(self) -> None:
        """Pull one more chunk from the stream into the buffer."""
        next_chunk = await self._stream.__anext__()
        self._content.extend(next_chunk)

    def get(self, *, final: bool = False) -> ModelStructuredResponse:
        """Build the `ModelStructuredResponse` from everything received so far.

        The whole buffer is re-validated on each call and the cost is recomputed from scratch.

        NOTE: Gemini appears to always answer with a single response when returning structured data, so it
        isn't clear how several streamed responses should be merged. Each part is therefore assumed to hold a
        complete tool call, no attempt is made to combine data from separate parts.
        """
        partial_mode = 'off' if final else 'trailing-strings'
        responses = _gemini_streamed_response_ta.validate_json(
            self._content,
            experimental_allow_partial=partial_mode,
        )

        function_calls: list[_GeminiFunctionCallPart] = []
        self._cost = result.Cost()
        for response in responses:
            self._cost += _metadata_as_cost(response)
            candidate = response['candidates'][0]
            parts = candidate['content']['parts']
            if _all_function_call_parts(parts):
                function_calls.extend(parts)
                continue
            # an empty text part can arrive together with the finish_reason, that case is ignored
            if not candidate.get('finish_reason'):
                raise UnexpectedModelBehavior(
                    'Streamed response with unexpected content, expected all parts to be function calls'
                )
        return _structured_response_from_parts(function_calls, timestamp=self._timestamp)

    def cost(self) -> result.Cost:
        """Return the cost computed by the most recent `get` call."""
        return self._cost

    def timestamp(self) -> datetime:
        """Return the timestamp taken when this object was created."""
        return self._timestamp
380
+
381
+
382
+ # We use typed dicts to define the Gemini API response schema
383
+ # once Pydantic partial validation supports dataclasses, we could revert to using them
384
+ # TypeAdapters take care of validation and serialization
385
+
386
+
387
class _GeminiRequest(TypedDict):
    """Schema for an API request to the Gemini API.

    Only `contents` is required; the other keys are added by the caller when it has a value for them.

    See <https://ai.google.dev/api/generate-content#request-body> for API docs.
    """

    # the conversation so far, excluding system prompts
    contents: list[_GeminiContent]
    # function declarations the model may call
    tools: NotRequired[_GeminiTools]
    # restricts how/which functions the model may call
    tool_config: NotRequired[_GeminiToolConfig]
    # we don't implement `generationConfig`, instead we use a named tool for the response
    system_instruction: NotRequired[_GeminiTextContent]
    """
    Developer generated system instructions, see
    <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
    """
402
+
403
+
404
class _GeminiContent(TypedDict):
    """A block of content attributed to either the user or the model."""

    # who the content is attributed to
    role: Literal['user', 'model']
    # text, function call or function response parts
    parts: list[_GeminiPartUnion]
407
+
408
+
409
def _content_user_text(text: str) -> _GeminiContent:
    """Wrap `text` in a `user` content entry with a single text part."""
    text_part = _GeminiTextPart(text=text)
    return _GeminiContent(role='user', parts=[text_part])
411
+
412
+
413
def _content_model_text(text: str) -> _GeminiContent:
    """Wrap `text` in a `model` content entry with a single text part."""
    text_part = _GeminiTextPart(text=text)
    return _GeminiContent(role='model', parts=[text_part])
415
+
416
+
417
def _content_function_call(m: ModelStructuredResponse) -> _GeminiContent:
    """Convert the tool calls of a structured model response into a `model` content entry."""
    parts: list[_GeminiPartUnion] = []
    for call in m.calls:
        parts.append(_function_call_part_from_call(call))
    return _GeminiContent(role='model', parts=parts)
420
+
421
+
422
def _content_function_return(m: ToolReturn) -> _GeminiContent:
    """Convert a tool return into a `user` content entry holding one function response part."""
    response_part = _response_part_from_response(m.tool_name, m.model_response_object())
    return _GeminiContent(role='user', parts=[response_part])
425
+
426
+
427
def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
    """Convert a retry prompt into a `user` content entry.

    A retry tied to a tool becomes a function response with the message under `call_error`;
    a retry without a tool name becomes plain text.
    """
    tool_name = m.tool_name
    if tool_name is not None:
        part = _response_part_from_response(tool_name, {'call_error': m.model_response()})
    else:
        part = _GeminiTextPart(text=m.model_response())
    return _GeminiContent(role='user', parts=[part])
434
+
435
+
436
class _GeminiTextPart(TypedDict):
    """A content part that carries plain text."""

    text: str
438
+
439
+
440
class _GeminiFunctionCallPart(TypedDict):
    """A content part that carries a function call."""

    # aliased to the camelCase `functionCall` key
    function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
442
+
443
+
444
def _function_call_part_from_call(tool: ToolCall) -> _GeminiFunctionCallPart:
    """Convert a `ToolCall` whose arguments are an `ArgsObject` into a function call part."""
    assert isinstance(tool.args, ArgsObject), f'Expected ArgsObject, got {tool.args}'
    function_call = _GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_object)
    return _GeminiFunctionCallPart(function_call=function_call)
447
+
448
+
449
def _structured_response_from_parts(
    parts: list[_GeminiFunctionCallPart], timestamp: datetime | None = None
) -> ModelStructuredResponse:
    """Build a `ModelStructuredResponse` with one tool call per function call part.

    The current UTC time is used when no `timestamp` is given.
    """
    calls = []
    for part in parts:
        function_call = part['function_call']
        calls.append(ToolCall.from_object(function_call['name'], function_call['args']))
    if not timestamp:
        timestamp = _utils.now_utc()
    return ModelStructuredResponse(calls=calls, timestamp=timestamp)
456
+
457
+
458
class _GeminiFunctionCall(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionCall>."""

    # name of the function being called
    name: str
    # arguments of the call as a JSON object
    args: dict[str, Any]
463
+
464
+
465
class _GeminiFunctionResponsePart(TypedDict):
    """A content part that carries the response of a function call."""

    # aliased to the camelCase `functionResponse` key
    function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
467
+
468
+
469
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
    """Wrap the response of the function `name` in a function response part."""
    function_response = _GeminiFunctionResponse(name=name, response=response)
    return _GeminiFunctionResponsePart(function_response=function_response)
471
+
472
+
473
class _GeminiFunctionResponse(TypedDict):
    """See <https://ai.google.dev/api/caching#FunctionResponse>."""

    # name of the function the response belongs to
    name: str
    # the function's result as a JSON object
    response: dict[str, Any]
478
+
479
+
480
def _part_discriminator(v: Any) -> str:
    """Return the tag of the part type that `v` should be validated or serialized as.

    The decision is based on which key is present, both the snake_case field names and their camelCase
    aliases are recognised. `'text'` wins when present and is also the fallback for anything unrecognised,
    including values that are not dicts.
    """
    if not isinstance(v, dict) or 'text' in v:
        return 'text'

    # checked in order, so a function call takes precedence over a function response
    keys_by_tag = (
        ('function_call', ('functionCall', 'function_call')),
        ('function_response', ('functionResponse', 'function_response')),
    )
    for tag, keys in keys_by_tag:
        if any(key in v for key in keys):
            return tag
    return 'text'
489
+
490
+
491
# See <https://ai.google.dev/api/caching#Part>
# we don't currently support other part types
# The union is discriminated by a callable: `_part_discriminator` returns one of the `Tag` values below.
_GeminiPartUnion = Annotated[
    Union[
        Annotated[_GeminiTextPart, Tag('text')],
        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
    ],
    Discriminator(_part_discriminator),
]
502
+
503
+
504
class _GeminiTextContent(TypedDict):
    """Content restricted to text parts, used for the request's `system_instruction`."""

    role: Literal['user', 'model']
    parts: list[_GeminiTextPart]
507
+
508
+
509
class _GeminiTools(TypedDict):
    """Tool declarations sent with a request: the functions the model is allowed to call.

    The alias is attached to the field itself rather than to the list items. An alias on the item type is
    not a field alias, so serializing with `by_alias=True` would never produce the camelCase
    `functionDeclarations` key used by the other aliased fields in this module.
    """

    function_declarations: Annotated[list[_GeminiFunction], Field(alias='functionDeclarations')]
511
+
512
+
513
class _GeminiFunction(TypedDict):
    """Declaration of a single function (tool) the model may call."""

    name: str
    description: str
    # JSON schema of the arguments, omitted when the function's schema has no properties
    parameters: NotRequired[dict[str, Any]]
    """
    ObjectJsonSchema isn't really true since Gemini only accepts a subset of JSON Schema
    <https://ai.google.dev/gemini-api/docs/function-calling#function_declarations>
    and
    <https://ai.google.dev/api/caching#FunctionDeclaration>
    """
523
+
524
+
525
def _function_from_abstract_tool(tool: AbstractToolDefinition) -> _GeminiFunction:
    """Convert a tool definition into a Gemini function declaration.

    The tool's JSON schema is simplified for Gemini first; `parameters` is only set when the simplified
    schema has (non-empty) `properties`.
    """
    simplified_schema = _GeminiJsonSchema(tool.json_schema).simplify()
    function = _GeminiFunction(name=tool.name, description=tool.description)
    if simplified_schema.get('properties'):
        function['parameters'] = simplified_schema
    return function
534
+
535
+
536
class _GeminiToolConfig(TypedDict):
    """Tool configuration sent with a request, wrapping the function calling config."""

    function_calling_config: _GeminiFunctionCallingConfig
538
+
539
+
540
def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
    """Build a tool config with mode `ANY` that allows exactly the given function names."""
    calling_config = _GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
    return _GeminiToolConfig(function_calling_config=calling_config)
544
+
545
+
546
class _GeminiFunctionCallingConfig(TypedDict):
    """Function calling mode together with the names of the functions the model may call."""

    # `_tool_config` always sets 'ANY'
    mode: Literal['ANY', 'AUTO']
    allowed_function_names: list[str]
549
+
550
+
551
class _GeminiResponse(TypedDict):
    """Schema for the response from the Gemini API.

    See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
    and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
    """

    # the code in this module only ever reads the first candidate
    candidates: list[_GeminiCandidates]
    # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
562
+
563
+
564
def _extract_response_parts(
    response: _GeminiResponse,
) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
    """Pull the parts out of the single candidate of a Gemini response.

    The result is Either a list of function calls (Either.left) or a list of text parts (Either.right);
    function calls are checked first.

    Raises:
        UnexpectedModelBehavior: If there isn't exactly one candidate, or the parts are neither all
            function calls nor all text.
    """
    candidates = response['candidates']
    if len(candidates) != 1:
        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')

    parts = candidates[0]['content']['parts']
    if _all_function_call_parts(parts):
        return _utils.Either(left=parts)
    if _all_text_parts(parts):
        return _utils.Either(right=parts)
    raise exceptions.UnexpectedModelBehavior(
        f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
    )
582
+
583
+
584
def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
    """Return whether every part has a `function_call` key (trivially true for an empty list)."""
    for part in parts:
        if 'function_call' not in part:
            return False
    return True
586
+
587
+
588
def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
    """Return whether every part has a `text` key (trivially true for an empty list)."""
    for part in parts:
        if 'text' not in part:
            return False
    return True
590
+
591
+
592
class _GeminiCandidates(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

    content: _GeminiContent
    # only 'STOP' is accepted here, any other value fails validation
    finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
    """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
    avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
    index: NotRequired[int]
    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
604
+
605
+
606
# NOTE(review): the docstring link below points at the `FinishReason` anchor, which looks copy-pasted;
# it presumably should reference the usage metadata section of the same page - confirm before changing.
class _GeminiUsageMetaData(TypedDict, total=False):
    """See <https://ai.google.dev/api/generate-content#FinishReason>.

    The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
    """

    # `total=False` already makes every key optional, so `NotRequired` below is redundant but harmless
    prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
    candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
    total_token_count: Annotated[int, Field(alias='totalTokenCount')]
    cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
616
+
617
+
618
def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
    """Translate the usage metadata of a response into a `Cost`.

    A response without usage metadata gives an empty `Cost`. Missing token counts default to 0 and the
    cached content token count is recorded in `details` only when it is non-zero.
    """
    metadata = response.get('usage_metadata')
    if metadata is None:
        return result.Cost()

    cached_tokens = metadata.get('cached_content_token_count')
    details: dict[str, int] = {'cached_content_token_count': cached_tokens} if cached_tokens else {}

    request_tokens = metadata.get('prompt_token_count', 0)
    response_tokens = metadata.get('candidates_token_count', 0)
    total_tokens = metadata.get('total_token_count', 0)
    return result.Cost(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=total_tokens,
        details=details,
    )
631
+
632
+
633
class _GeminiSafetyRating(TypedDict):
    """See <https://ai.google.dev/gemini-api/docs/safety-settings#safety-filters>."""

    # the harm category the rating applies to
    category: Literal[
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    # the probability level reported for that category
    probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
644
+
645
+
646
class _GeminiPromptFeedback(TypedDict):
    """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

    # both keys are required and are read from their camelCase aliases
    block_reason: Annotated[str, Field(alias='blockReason')]
    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
651
+
652
+
653
# Type adapters used to serialize requests and to validate responses
_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)

# stream requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = _pydantic.LazyTypeAdapter(list[_GeminiResponse])
658
+
659
+
660
class _GeminiJsonSchema:
    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.

    Gemini [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
    a subset of OpenAPI v3.0.3.

    Specifically:
    * gemini doesn't allow the `title` keyword to be set
    * gemini doesn't allow `$defs` — we need to inline the definitions where possible

    `default` is removed as well. The schema passed in is deep-copied, so the caller's object is not modified.
    """

    def __init__(self, schema: _utils.ObjectJsonSchema):
        self.schema = deepcopy(schema)
        # definitions are taken out of the schema and inlined wherever they are referenced
        self.defs = self.schema.pop('$defs', {})

    def simplify(self) -> dict[str, Any]:
        """Simplify the schema in place and return it.

        Raises:
            exceptions.UserError: If the schema uses recursive `$ref`s or (truthy) `additionalProperties`.
        """
        self._simplify(self.schema, refs_stack=())
        return self.schema

    def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        """Simplify one (sub-)schema in place; `refs_stack` holds the definitions currently being inlined."""
        schema.pop('title', None)
        schema.pop('default', None)
        if ref := schema.pop('$ref', None):
            # noinspection PyTypeChecker
            key = re.sub(r'^#/\$defs/', '', ref)
            if key in refs_stack:
                raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
            refs_stack += (key,)
            schema_def = self.defs[key]
            self._simplify(schema_def, refs_stack)
            # inline the (now simplified) definition in place of the reference
            schema.update(schema_def)
            return

        if any_of := schema.get('anyOf'):
            # a separate loop variable is essential: rebinding `schema` here would make the type
            # dispatch below act on the last `anyOf` member instead of the schema being simplified
            for item_schema in any_of:
                self._simplify(item_schema, refs_stack)

        type_ = schema.get('type')

        if type_ == 'object':
            self._object(schema, refs_stack)
        elif type_ == 'array':
            self._array(schema, refs_stack)

    def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        """Simplify an object schema: drop `additionalProperties` and simplify each property."""
        ad_props = schema.pop('additionalProperties', None)
        if ad_props:
            raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')

        if properties := schema.get('properties'):  # pragma: no branch
            for value in properties.values():
                self._simplify(value, refs_stack)

    def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        """Simplify an array schema: its `prefixItems` and its `items` schema."""
        if prefix_items := schema.get('prefixItems'):
            # TODO I think this is not supported by Gemini, maybe we should raise an error?
            for prefix_item in prefix_items:
                self._simplify(prefix_item, refs_stack)

        if items_schema := schema.get('items'):  # pragma: no branch
            self._simplify(items_schema, refs_stack)