pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

@@ -2,29 +2,32 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
 
+import pydantic
 import pydantic_core
-from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from pydantic import Discriminator, Field, Tag
+from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
 
-from .. import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result
+from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
-    ArgsDict,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -37,7 +40,9 @@ from . import (
     get_user_agent,
 )
 
-GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro']
+GeminiModelName = Literal[
+    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
+]
 """Named Gemini models.
 
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
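
The hunk above adds `gemini-2.0-flash-exp` to the `GeminiModelName` literal. A minimal sketch of selecting the new model (the `GeminiModel` constructor and its `api_key` keyword are assumptions based on unchanged parts of this module that the diff does not show):

```python
from pydantic_ai.models.gemini import GeminiModel

# Assumed constructor shape; the model name is accepted by type checkers
# because it is now part of the GeminiModelName literal.
model = GeminiModel('gemini-2.0-flash-exp', api_key='<your-api-key>')
```
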
@@ -164,26 +169,25 @@ class GeminiAgentModel(AgentModel):
         self.tool_config = tool_config
         self.url = url
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        async with self._make_request(messages, False) as http_response:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
-        return self._process_response(response), _metadata_as_cost(response)
+        return self._process_response(response), _metadata_as_usage(response)
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        async with self._make_request(messages, True) as http_response:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
-    async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]:
-        contents: list[_GeminiContent] = []
-        sys_prompt_parts: list[_GeminiTextPart] = []
-        for m in messages:
-            either_content = self._message_to_gemini(m)
-            if left := either_content.left:
-                sys_prompt_parts.append(left.value)
-            else:
-                contents.append(either_content.right)
+    async def _make_request(
+        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+    ) -> AsyncIterator[HTTPResponse]:
+        sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
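
`request` and `request_stream` now take a second `model_settings` argument, threaded through to `_make_request`. Judging by the `from ..settings import ModelSettings` import added above and the `.get(...)` calls below, `ModelSettings` is a `total=False` TypedDict, so a plain dict literal works; a sketch:

```python
from pydantic_ai.settings import ModelSettings

# Keys left out are absent entirely, which is why the code below can
# distinguish "not set" from explicit falsy values like 0.0.
settings: ModelSettings = {'temperature': 0.0, 'max_tokens': 512, 'timeout': 30.0}
```
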
@@ -193,6 +197,17 @@ class GeminiAgentModel(AgentModel):
         if self.tool_config is not None:
             request_data['tool_config'] = self.tool_config
 
+        generation_config: _GeminiGenerationConfig = {}
+        if model_settings:
+            if (max_tokens := model_settings.get('max_tokens')) is not None:
+                generation_config['max_output_tokens'] = max_tokens
+            if (temperature := model_settings.get('temperature')) is not None:
+                generation_config['temperature'] = temperature
+            if (top_p := model_settings.get('top_p')) is not None:
+                generation_config['top_p'] = top_p
+        if generation_config:
+            request_data['generation_config'] = generation_config
+
         url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
 
         headers = {
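
Only the three settings Gemini's `generationConfig` understands are forwarded, and the walrus-plus-`is not None` guard keeps explicit zeros while dropping unset keys. The same pattern in a standalone, runnable form (the function name and plain dicts are illustrative, not from the diff):

```python
from typing import Any


def settings_to_generation_config(model_settings: dict[str, Any] | None) -> dict[str, Any]:
    """Copy only the keys the caller actually set, preserving falsy values like 0.0."""
    generation_config: dict[str, Any] = {}
    for settings_key, gemini_key in [
        ('max_tokens', 'max_output_tokens'),
        ('temperature', 'temperature'),
        ('top_p', 'top_p'),
    ]:
        if model_settings and (value := model_settings.get(settings_key)) is not None:
            generation_config[gemini_key] = value
    return generation_config


# temperature=0.0 survives because the guard is `is not None`, not truthiness
assert settings_to_generation_config({'temperature': 0.0}) == {'temperature': 0.0}
```
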
@@ -203,19 +218,24 @@ class GeminiAgentModel(AgentModel):
 
         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
 
-        async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
+        async with self.http_client.stream(
+            'POST',
+            url,
+            content=request_json,
+            headers=headers,
+            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+        ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r
 
     @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelAnyResponse:
-        either = _extract_response_parts(response)
-        if left := either.left:
-            return _structured_response_from_parts(left.value)
-        else:
-            return ModelTextResponse(content=''.join(part['text'] for part in either.right))
+    def _process_response(response: _GeminiResponse) -> ModelResponse:
+        if len(response['candidates']) != 1:
+            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        parts = response['candidates'][0]['content']['parts']
+        return _process_response_from_parts(parts)
 
     @staticmethod
     async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
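
`USE_CLIENT_DEFAULT` is httpx's sentinel meaning "use whatever timeout the client was configured with"; passing `None` would instead disable the timeout. A self-contained sketch of the same fallback pattern (the URL and function are illustrative):

```python
import httpx


async def fetch(client: httpx.AsyncClient, timeout: float | None = None) -> str:
    # `None` from the caller means "unspecified", so fall back to the client
    # default rather than disabling timeouts (which timeout=None would do).
    response = await client.get(
        'https://example.com',
        timeout=timeout if timeout is not None else httpx.USE_CLIENT_DEFAULT,
    )
    return response.text
```
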
@@ -239,34 +259,37 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
+        # TODO: Update this once we rework stream responses to be more flexible
         if _extract_response_parts(start_response).is_left():
             return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
         else:
             return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
 
-    @staticmethod
-    def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
-        """Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents"."""
-        if m.role == 'system':
-            # SystemPrompt ->
-            return _utils.Either(left=_GeminiTextPart(text=m.content))
-        elif m.role == 'user':
-            # UserPrompt ->
-            return _utils.Either(right=_content_user_text(m.content))
-        elif m.role == 'tool-return':
-            # ToolReturn ->
-            return _utils.Either(right=_content_function_return(m))
-        elif m.role == 'retry-prompt':
-            # RetryPrompt ->
-            return _utils.Either(right=_content_function_retry(m))
-        elif m.role == 'model-text-response':
-            # ModelTextResponse ->
-            return _utils.Either(right=_content_model_text(m.content))
-        elif m.role == 'model-structured-response':
-            # ModelStructuredResponse ->
-            return _utils.Either(right=_content_function_call(m))
-        else:
-            assert_never(m)
+    @classmethod
+    def _message_to_gemini_content(
+        cls, messages: list[ModelMessage]
+    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
+        sys_prompt_parts: list[_GeminiTextPart] = []
+        contents: list[_GeminiContent] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
+                    elif isinstance(part, UserPromptPart):
+                        contents.append(_content_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        contents.append(_content_tool_return(part))
+                    elif isinstance(part, RetryPromptPart):
+                        contents.append(_content_retry_prompt(part))
+                    else:
+                        assert_never(part)
+            elif isinstance(m, ModelResponse):
+                contents.append(_content_model_response(m))
+            else:
+                assert_never(m)
+
+        return sys_prompt_parts, contents
 
 
 @dataclass
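
The new classmethod replaces the role-string dispatch with `isinstance` checks over the 0.0.14 message-part classes. Roughly, for a typical request (constructor kwargs follow the part classes as used in this diff; treat the exact output shapes as assumptions):

```python
from pydantic_ai.messages import ModelRequest, SystemPromptPart, UserPromptPart

request = ModelRequest(
    parts=[
        SystemPromptPart(content='You are terse.'),
        UserPromptPart(content='What is 2 + 2?'),
    ]
)
# _message_to_gemini_content([request]) would yield roughly:
#   sys_prompt_parts -> [{'text': 'You are terse.'}]
#   contents         -> [{'role': 'user', 'parts': [{'text': 'What is 2 + 2?'}]}]
```
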
@@ -277,7 +300,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
     _stream: AsyncIterator[bytes]
     _position: int = 0
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)
 
     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -297,7 +320,7 @@
             new_items, experimental_allow_partial='trailing-strings'
         )
         for r in new_responses:
-            self._cost += _metadata_as_cost(r)
+            self._usage += _metadata_as_usage(r)
             parts = r['candidates'][0]['content']['parts']
             if _all_text_parts(parts):
                 for part in parts:
@@ -307,8 +330,8 @@
                     'Streamed response with unexpected content, expected all parts to be text'
                 )
 
-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -321,14 +344,14 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)
 
     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
         self._content.extend(chunk)
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        """Get the `ModelStructuredResponse` at this point.
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.
 
         NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
         reply with a single response, when returning a structured data.
@@ -340,23 +363,16 @@
             self._content,
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
-        combined_parts: list[_GeminiFunctionCallPart] = []
-        self._cost = result.Cost()
+        combined_parts: list[_GeminiPartUnion] = []
+        self._usage = result.Usage()
         for r in responses:
-            self._cost += _metadata_as_cost(r)
+            self._usage += _metadata_as_usage(r)
             candidate = r['candidates'][0]
-            parts = candidate['content']['parts']
-            if _all_function_call_parts(parts):
-                combined_parts.extend(parts)
-            elif not candidate.get('finish_reason'):
-                # you can get an empty text part along with the finish_reason, so we ignore that case
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be function calls'
-                )
-        return _structured_response_from_parts(combined_parts, timestamp=self._timestamp)
+            combined_parts.extend(candidate['content']['parts'])
+        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
 
-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -367,6 +383,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
 # TypeAdapters take care of validation and serialization
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiRequest(TypedDict):
     """Schema for an API request to the Gemini API.
 
@@ -382,32 +399,37 @@ class _GeminiRequest(TypedDict):
     Developer generated system instructions, see
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
+    generation_config: NotRequired[_GeminiGenerationConfig]
 
 
-class _GeminiContent(TypedDict):
-    role: Literal['user', 'model']
-    parts: list[_GeminiPartUnion]
+class _GeminiGenerationConfig(TypedDict, total=False):
+    """Schema for an API request to the Gemini API.
 
+    Note there are many additional fields available that have not been added yet.
 
-def _content_user_text(text: str) -> _GeminiContent:
-    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=text)])
+    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
+    """
+
+    max_output_tokens: int
+    temperature: float
+    top_p: float
 
 
-def _content_model_text(text: str) -> _GeminiContent:
-    return _GeminiContent(role='model', parts=[_GeminiTextPart(text=text)])
+class _GeminiContent(TypedDict):
+    role: Literal['user', 'model']
+    parts: list[_GeminiPartUnion]
 
 
-def _content_function_call(m: ModelStructuredResponse) -> _GeminiContent:
-    parts: list[_GeminiPartUnion] = [_function_call_part_from_call(t) for t in m.calls]
-    return _GeminiContent(role='model', parts=parts)
+def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
+    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
 
 
-def _content_function_return(m: ToolReturn) -> _GeminiContent:
+def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
     f_response = _response_part_from_response(m.tool_name, m.model_response_object())
     return _GeminiContent(role='user', parts=[f_response])
 
 
-def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
+def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
     if m.tool_name is None:
         part = _GeminiTextPart(text=m.model_response())
     else:
@@ -416,26 +438,42 @@ def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
     return _GeminiContent(role='user', parts=[part])
 
 
+def _content_model_response(m: ModelResponse) -> _GeminiContent:
+    parts: list[_GeminiPartUnion] = []
+    for item in m.parts:
+        if isinstance(item, ToolCallPart):
+            parts.append(_function_call_part_from_call(item))
+        elif isinstance(item, TextPart):
+            parts.append(_GeminiTextPart(text=item.content))
+        else:
+            assert_never(item)
+    return _GeminiContent(role='model', parts=parts)
+
+
 class _GeminiTextPart(TypedDict):
     text: str
 
 
 class _GeminiFunctionCallPart(TypedDict):
-    function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
+    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
 
-def _function_call_part_from_call(tool: ToolCall) -> _GeminiFunctionCallPart:
-    assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}'
-    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
+def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
+    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
-def _structured_response_from_parts(
-    parts: list[_GeminiFunctionCallPart], timestamp: datetime | None = None
-) -> ModelStructuredResponse:
-    return ModelStructuredResponse(
-        calls=[ToolCall.from_dict(part['function_call']['name'], part['function_call']['args']) for part in parts],
-        timestamp=timestamp or _utils.now_utc(),
-    )
+def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+    items: list[ModelResponsePart] = []
+    for part in parts:
+        if 'text' in part:
+            items.append(TextPart(part['text']))
+        elif 'function_call' in part:
+            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
+        elif 'function_response' in part:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+            )
+    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):
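
`_process_response_from_parts` replaces the function-call-only `_structured_response_from_parts` and now handles mixed text and tool-call parts by checking which key each raw part dict carries. The dispatch in isolation, over plain dicts so it runs standalone:

```python
raw_parts = [
    {'text': 'Checking the weather now.'},
    {'function_call': {'name': 'get_weather', 'args': {'city': 'Paris'}}},
]

for part in raw_parts:
    if 'text' in part:
        print('text part:', part['text'])  # -> TextPart in the real code
    elif 'function_call' in part:
        call = part['function_call']
        print('tool call:', call['name'], call['args'])  # -> ToolCallPart.from_raw_args
```
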
@@ -446,7 +484,7 @@ class _GeminiFunctionCall(TypedDict):
 
 
 class _GeminiFunctionResponsePart(TypedDict):
-    function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
+    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
 
 
 def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
@@ -476,11 +514,11 @@ def _part_discriminator(v: Any) -> str:
 # TODO discriminator
 _GeminiPartUnion = Annotated[
     Union[
-        Annotated[_GeminiTextPart, Tag('text')],
-        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
-        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
+        Annotated[_GeminiTextPart, pydantic.Tag('text')],
+        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
+        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
     ],
-    Discriminator(_part_discriminator),
+    pydantic.Discriminator(_part_discriminator),
 ]
 
 
@@ -490,7 +528,7 @@ class _GeminiTextContent(TypedDict):
 
 
 class _GeminiTools(TypedDict):
-    function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
+    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
 
 
 class _GeminiFunction(TypedDict):
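
The `_GeminiPartUnion` hunk above switches to pydantic's callable-discriminator API: `pydantic.Discriminator` takes a function that inspects the raw value and returns the matching `pydantic.Tag` name. A self-contained sketch of the mechanism (type names are illustrative):

```python
from typing import Annotated, Any, Union

import pydantic
from typing_extensions import TypedDict


class TextPartD(TypedDict):
    text: str


class CallPartD(TypedDict):
    function_call: dict[str, Any]


def discriminate(v: Any) -> str:
    # Pick the Tag name based on which key the raw dict carries.
    return 'function_call' if isinstance(v, dict) and 'function_call' in v else 'text'


PartUnion = Annotated[
    Union[
        Annotated[TextPartD, pydantic.Tag('text')],
        Annotated[CallPartD, pydantic.Tag('function_call')],
    ],
    pydantic.Discriminator(discriminate),
]

adapter = pydantic.TypeAdapter(list[PartUnion])
assert adapter.validate_python([{'text': 'hi'}]) == [{'text': 'hi'}]
```
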
@@ -531,6 +569,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
     allowed_function_names: list[str]
 
 
+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiResponse(TypedDict):
     """Schema for the response from the Gemini API.
 
@@ -540,10 +579,11 @@ class _GeminiResponse(TypedDict):
 
     candidates: list[_GeminiCandidates]
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
-    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
-    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
+    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
+    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
+# TODO: Delete the next three functions once we've reworked streams to be more flexible
 def _extract_response_parts(
     response: _GeminiResponse,
 ) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
@@ -576,14 +616,14 @@ class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
     content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
     """
    See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
    but let's wait until we see them and know what they mean to add them here.
    """
-    avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
+    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
     index: NotRequired[int]
-    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
+    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
 class _GeminiUsageMetaData(TypedDict, total=False):
@@ -592,20 +632,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
 
-    prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
-    candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
-    total_token_count: Annotated[int, Field(alias='totalTokenCount')]
-    cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
+    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
+    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
+    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
+    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
+def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Cost()
+        return result.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Cost(
+    return result.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
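
The rename from `_metadata_as_cost` to `_metadata_as_usage` keeps the token accounting identical (the camelCase aliases have already been mapped to snake_case by the response TypeAdapter at this point). The mapping in isolation, with plain dicts standing in for `_GeminiUsageMetaData` and `result.Usage`:

```python
def usage_from_metadata(metadata: dict[str, int] | None) -> dict[str, int]:
    # Mirrors the diff: missing counts default to 0 rather than raising.
    if metadata is None:
        return {'request_tokens': 0, 'response_tokens': 0, 'total_tokens': 0}
    return {
        'request_tokens': metadata.get('prompt_token_count', 0),
        'response_tokens': metadata.get('candidates_token_count', 0),
        'total_tokens': metadata.get('total_token_count', 0),
    }


assert usage_from_metadata({'prompt_token_count': 10, 'total_token_count': 10}) == {
    'request_tokens': 10,
    'response_tokens': 0,
    'total_tokens': 10,
}
```
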
@@ -629,15 +669,15 @@ class _GeminiSafetyRating(TypedDict):
 class _GeminiPromptFeedback(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
 
-    block_reason: Annotated[str, Field(alias='blockReason')]
-    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
+    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
+    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
 
 
-_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
-_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)
+_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
+_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 
 # steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
-_gemini_streamed_response_ta = _pydantic.LazyTypeAdapter(list[_GeminiResponse])
+_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
 
 
 class _GeminiJsonSchema:
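
The final hunk swaps the project's private `_pydantic.LazyTypeAdapter` for plain `pydantic.TypeAdapter`, with laziness recovered via `defer_build=True`: applied to the TypedDicts through the `pydantic.with_config` decorator added earlier in this diff, and to the streamed-response adapter via the `config=` argument. A sketch of the pattern (assuming a recent pydantic 2.x where `with_config` and `defer_build` are available):

```python
import pydantic
from typing_extensions import TypedDict


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class Point(TypedDict):
    x: int
    y: int


# Core-schema construction is postponed until first use, keeping import cheap.
adapter = pydantic.TypeAdapter(Point)
assert adapter.validate_json('{"x": 1, "y": 2}') == {'x': 1, 'y': 2}
```
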