pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/models/function.py

@@ -4,7 +4,7 @@ import inspect
 import re
 from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast
@@ -13,7 +13,20 @@ import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload

 from .. import _utils, result
-from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
+from ..messages import (
+    ArgsJson,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse

@@ -59,7 +72,7 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
+            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
         )

     def name(self) -> str:
@@ -88,6 +101,8 @@ class AgentInfo:
     """Whether a plain text result is allowed."""
     result_tools: list[ToolDefinition]
     """The tools that can called as the final result of the run."""
+    model_settings: ModelSettings | None
+    """The model settings passed to the run call."""


 @dataclass
@@ -106,10 +121,10 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""

-FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]]
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""

-StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.

 While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
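
These two type aliases are the public contract for custom function models, so the switch from `Message`/`ModelAnyResponse` to `ModelMessage`/`ModelResponse` is a breaking change for user code. A minimal sketch of a non-streamed `FunctionDef` against the new types (the import paths and the `Agent`/`run_sync` usage are assumptions about pydantic-ai-slim 0.0.13, not taken from this diff):

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def echo_settings(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # `info.model_settings` is the field added in this release; it may be None.
    temperature = (info.model_settings or {}).get('temperature')
    return ModelResponse(parts=[TextPart(f'temperature={temperature}')])


agent = Agent(FunctionModel(echo_settings))
print(agent.run_sync('hello').data)
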
@@ -127,18 +142,25 @@ class FunctionAgentModel(AgentModel):
     stream_function: StreamFunctionDef | None
     agent_info: AgentInfo

-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        agent_info = replace(self.agent_info, model_settings=model_settings)
+
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
         if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, self.agent_info)
+            response = await self.function(messages, agent_info)
         else:
-            response_ = await _utils.run_in_executor(self.function, messages, self.agent_info)
-            response = cast(ModelAnyResponse, response_)
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_cost(chain(messages, [response]))

     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
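
For the streamed path, a `StreamFunctionDef` now also takes `list[ModelMessage]` and sees the per-run settings via `AgentInfo`. A hedged sketch (the keyword-only `stream_function` argument is an assumption about the `FunctionModel` constructor):

from collections.abc import AsyncIterator

from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, FunctionModel


async def stream_greeting(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
    # Yield plain text deltas; `DeltaToolCalls` mappings could be yielded instead for tool calls.
    for chunk in ('hello ', 'world'):
        yield chunk


stream_model = FunctionModel(stream_function=stream_greeting)
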
@@ -206,13 +228,13 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[key] = new

-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(ToolCall.from_json(c.name, c.json_args))
+                calls.append(ToolCallPart.from_json(c.name, c.json_args))

-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(calls, timestamp=self._timestamp)

     def cost(self) -> result.Cost:
         return result.Cost()
@@ -221,32 +243,38 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp


-def _estimate_cost(messages: Iterable[Message]) -> result.Cost:
+def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
     """Very rough guesstimate of the number of tokens associate with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
     """
     # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-
     request_tokens = 50
     response_tokens = 0
     for message in messages:
-        if message.role == 'system' or message.role == 'user':
-            request_tokens += _string_cost(message.content)
-        elif message.role == 'tool-return':
-            request_tokens += _string_cost(message.model_response_str())
-        elif message.role == 'retry-prompt':
-            request_tokens += _string_cost(message.model_response())
-        elif message.role == 'model-text-response':
-            response_tokens += _string_cost(message.content)
-        elif message.role == 'model-structured-response':
-            for call in message.calls:
-                if isinstance(call.args, ArgsJson):
-                    args_str = call.args.args_json
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                    request_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolReturnPart):
+                    request_tokens += _string_cost(part.model_response_str())
+                elif isinstance(part, RetryPromptPart):
+                    request_tokens += _string_cost(part.model_response())
                 else:
-                    args_str = pydantic_core.to_json(call.args.args_dict).decode()
-
-                response_tokens += 1 + _string_cost(args_str)
+                    assert_never(part)
+        elif isinstance(message, ModelResponse):
+            for part in message.parts:
+                if isinstance(part, TextPart):
+                    response_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolCallPart):
+                    call = part
+                    if isinstance(call.args, ArgsJson):
+                        args_str = call.args.args_json
+                    else:
+                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
+                    response_tokens += 1 + _string_cost(args_str)
+                else:
+                    assert_never(part)
         else:
             assert_never(message)
     return result.Cost(

pydantic_ai/models/gemini.py

@@ -2,29 +2,33 @@ from __future__ import annotations as _annotations

 import os
 import re
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union

+import pydantic
 import pydantic_core
-from httpx import AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from pydantic import Discriminator, Field, Tag
+from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never

-from .. import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result
+from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
     ArgsDict,
-    Message,
-    ModelAnyResponse,
-    ModelStructuredResponse,
-    ModelTextResponse,
-    RetryPrompt,
-    ToolCall,
-    ToolReturn,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -37,7 +41,9 @@ from . import (
     get_user_agent,
 )

-GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro']
+GeminiModelName = Literal[
+    'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro', 'gemini-2.0-flash-exp'
+]
 """Named Gemini models.

 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
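
With `gemini-2.0-flash-exp` added to `GeminiModelName`, the experimental Gemini 2.0 Flash model can be selected like any other named model. A usage sketch (the `api_key` argument and the `GEMINI_API_KEY` environment-variable fallback are assumptions about the `GeminiModel` constructor, which is not shown in this diff):

from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel

# Falls back to the GEMINI_API_KEY environment variable if no key is passed (assumed behaviour).
model = GeminiModel('gemini-2.0-flash-exp', api_key='...')
agent = Agent(model)
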
@@ -164,26 +170,25 @@ class GeminiAgentModel(AgentModel):
         self.tool_config = tool_config
         self.url = url

-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
-        async with self._make_request(messages, False) as http_response:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
         return self._process_response(response), _metadata_as_cost(response)

     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
-        async with self._make_request(messages, True) as http_response:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)

     @asynccontextmanager
-    async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncIterator[HTTPResponse]:
-        contents: list[_GeminiContent] = []
-        sys_prompt_parts: list[_GeminiTextPart] = []
-        for m in messages:
-            either_content = self._message_to_gemini(m)
-            if left := either_content.left:
-                sys_prompt_parts.append(left.value)
-            else:
-                contents.append(either_content.right)
+    async def _make_request(
+        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+    ) -> AsyncIterator[HTTPResponse]:
+        sys_prompt_parts, contents = self._message_to_gemini_content(messages)

         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -193,6 +198,17 @@ class GeminiAgentModel(AgentModel):
         if self.tool_config is not None:
             request_data['tool_config'] = self.tool_config

+        generation_config: _GeminiGenerationConfig = {}
+        if model_settings:
+            if (max_tokens := model_settings.get('max_tokens')) is not None:
+                generation_config['max_output_tokens'] = max_tokens
+            if (temperature := model_settings.get('temperature')) is not None:
+                generation_config['temperature'] = temperature
+            if (top_p := model_settings.get('top_p')) is not None:
+                generation_config['top_p'] = top_p
+        if generation_config:
+            request_data['generation_config'] = generation_config
+
         url = self.url + ('streamGenerateContent' if streamed else 'generateContent')

         headers = {
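
The block above is where the new per-run `ModelSettings` reach the Gemini request body: `max_tokens`, `temperature` and `top_p` are copied into `generation_config`, and `timeout` is handed to httpx further down in this diff. A sketch of how a caller might supply them (assuming the agent run methods accept a `model_settings` argument in 0.0.13):

from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')
result = agent.run_sync(
    'Summarise the plot of Hamlet in one sentence.',
    model_settings={'temperature': 0.0, 'max_tokens': 256, 'timeout': 30},
)
print(result.data)
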
@@ -203,19 +219,24 @@ class GeminiAgentModel(AgentModel):

         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

-        async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
+        async with self.http_client.stream(
+            'POST',
+            url,
+            content=request_json,
+            headers=headers,
+            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+        ) as r:
             if r.status_code != 200:
                 await r.aread()
                 raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
             yield r

     @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelAnyResponse:
-        either = _extract_response_parts(response)
-        if left := either.left:
-            return _structured_response_from_parts(left.value)
-        else:
-            return ModelTextResponse(content=''.join(part['text'] for part in either.right))
+    def _process_response(response: _GeminiResponse) -> ModelResponse:
+        if len(response['candidates']) != 1:
+            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
+        parts = response['candidates'][0]['content']['parts']
+        return _process_response_from_parts(parts)

     @staticmethod
     async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
@@ -239,34 +260,37 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

+        # TODO: Update this once we rework stream responses to be more flexible
         if _extract_response_parts(start_response).is_left():
             return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
         else:
             return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)

-    @staticmethod
-    def _message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiContent]:
-        """Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents"."""
-        if m.role == 'system':
-            # SystemPrompt ->
-            return _utils.Either(left=_GeminiTextPart(text=m.content))
-        elif m.role == 'user':
-            # UserPrompt ->
-            return _utils.Either(right=_content_user_text(m.content))
-        elif m.role == 'tool-return':
-            # ToolReturn ->
-            return _utils.Either(right=_content_function_return(m))
-        elif m.role == 'retry-prompt':
-            # RetryPrompt ->
-            return _utils.Either(right=_content_function_retry(m))
-        elif m.role == 'model-text-response':
-            # ModelTextResponse ->
-            return _utils.Either(right=_content_model_text(m.content))
-        elif m.role == 'model-structured-response':
-            # ModelStructuredResponse ->
-            return _utils.Either(right=_content_function_call(m))
-        else:
-            assert_never(m)
+    @classmethod
+    def _message_to_gemini_content(
+        cls, messages: list[ModelMessage]
+    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
+        sys_prompt_parts: list[_GeminiTextPart] = []
+        contents: list[_GeminiContent] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
+                    elif isinstance(part, UserPromptPart):
+                        contents.append(_content_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        contents.append(_content_tool_return(part))
+                    elif isinstance(part, RetryPromptPart):
+                        contents.append(_content_retry_prompt(part))
+                    else:
+                        assert_never(part)
+            elif isinstance(m, ModelResponse):
+                contents.append(_content_model_response(m))
+            else:
+                assert_never(m)
+
+        return sys_prompt_parts, contents


 @dataclass
@@ -327,8 +351,8 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             chunk = await self._stream.__anext__()
             self._content.extend(chunk)

-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        """Get the `ModelStructuredResponse` at this point.
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.

         NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
         reply with a single response, when returning a structured data.
@@ -340,20 +364,13 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             self._content,
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
-        combined_parts: list[_GeminiFunctionCallPart] = []
+        combined_parts: list[_GeminiPartUnion] = []
         self._cost = result.Cost()
         for r in responses:
             self._cost += _metadata_as_cost(r)
             candidate = r['candidates'][0]
-            parts = candidate['content']['parts']
-            if _all_function_call_parts(parts):
-                combined_parts.extend(parts)
-            elif not candidate.get('finish_reason'):
-                # you can get an empty text part along with the finish_reason, so we ignore that case
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be function calls'
-                )
-        return _structured_response_from_parts(combined_parts, timestamp=self._timestamp)
+            combined_parts.extend(candidate['content']['parts'])
+        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)

     def cost(self) -> result.Cost:
         return self._cost
@@ -367,6 +384,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
 # TypeAdapters take care of validation and serialization


+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiRequest(TypedDict):
     """Schema for an API request to the Gemini API.

@@ -382,32 +400,37 @@ class _GeminiRequest(TypedDict):
     Developer generated system instructions, see
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
+    generation_config: NotRequired[_GeminiGenerationConfig]


-class _GeminiContent(TypedDict):
-    role: Literal['user', 'model']
-    parts: list[_GeminiPartUnion]
+class _GeminiGenerationConfig(TypedDict, total=False):
+    """Schema for an API request to the Gemini API.

+    Note there are many additional fields available that have not been added yet.

-def _content_user_text(text: str) -> _GeminiContent:
-    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=text)])
+    See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
+    """
+
+    max_output_tokens: int
+    temperature: float
+    top_p: float


-def _content_model_text(text: str) -> _GeminiContent:
-    return _GeminiContent(role='model', parts=[_GeminiTextPart(text=text)])
+class _GeminiContent(TypedDict):
+    role: Literal['user', 'model']
+    parts: list[_GeminiPartUnion]


-def _content_function_call(m: ModelStructuredResponse) -> _GeminiContent:
-    parts: list[_GeminiPartUnion] = [_function_call_part_from_call(t) for t in m.calls]
-    return _GeminiContent(role='model', parts=parts)
+def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
+    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])


-def _content_function_return(m: ToolReturn) -> _GeminiContent:
+def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
     f_response = _response_part_from_response(m.tool_name, m.model_response_object())
     return _GeminiContent(role='user', parts=[f_response])


-def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
+def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
     if m.tool_name is None:
         part = _GeminiTextPart(text=m.model_response())
     else:
@@ -416,26 +439,43 @@ def _content_function_retry(m: RetryPrompt) -> _GeminiContent:
     return _GeminiContent(role='user', parts=[part])


+def _content_model_response(m: ModelResponse) -> _GeminiContent:
+    parts: list[_GeminiPartUnion] = []
+    for item in m.parts:
+        if isinstance(item, ToolCallPart):
+            parts.append(_function_call_part_from_call(item))
+        elif isinstance(item, TextPart):
+            parts.append(_GeminiTextPart(text=item.content))
+        else:
+            assert_never(item)
+    return _GeminiContent(role='model', parts=parts)
+
+
 class _GeminiTextPart(TypedDict):
     text: str


 class _GeminiFunctionCallPart(TypedDict):
-    function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
+    function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]


-def _function_call_part_from_call(tool: ToolCall) -> _GeminiFunctionCallPart:
+def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
     assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}'
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))


-def _structured_response_from_parts(
-    parts: list[_GeminiFunctionCallPart], timestamp: datetime | None = None
-) -> ModelStructuredResponse:
-    return ModelStructuredResponse(
-        calls=[ToolCall.from_dict(part['function_call']['name'], part['function_call']['args']) for part in parts],
-        timestamp=timestamp or _utils.now_utc(),
-    )
+def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+    items: list[ModelResponsePart] = []
+    for part in parts:
+        if 'text' in part:
+            items.append(TextPart(part['text']))
+        elif 'function_call' in part:
+            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+        elif 'function_response' in part:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+            )
+    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())


 class _GeminiFunctionCall(TypedDict):
@@ -446,7 +486,7 @@ class _GeminiFunctionCall(TypedDict):


 class _GeminiFunctionResponsePart(TypedDict):
-    function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
+    function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]


 def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
@@ -476,11 +516,11 @@ def _part_discriminator(v: Any) -> str:
 # TODO discriminator
 _GeminiPartUnion = Annotated[
     Union[
-        Annotated[_GeminiTextPart, Tag('text')],
-        Annotated[_GeminiFunctionCallPart, Tag('function_call')],
-        Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
+        Annotated[_GeminiTextPart, pydantic.Tag('text')],
+        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
+        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
     ],
-    Discriminator(_part_discriminator),
+    pydantic.Discriminator(_part_discriminator),
 ]

@@ -490,7 +530,7 @@ class _GeminiTextContent(TypedDict):


 class _GeminiTools(TypedDict):
-    function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
+    function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]


 class _GeminiFunction(TypedDict):
@@ -531,6 +571,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
     allowed_function_names: list[str]


+@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
 class _GeminiResponse(TypedDict):
     """Schema for the response from the Gemini API.

@@ -540,10 +581,11 @@ class _GeminiResponse(TypedDict):

     candidates: list[_GeminiCandidates]
     # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
-    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
-    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
+    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
+    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]


+# TODO: Delete the next three functions once we've reworked streams to be more flexible
 def _extract_response_parts(
     response: _GeminiResponse,
 ) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
@@ -576,14 +618,14 @@ class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

     content: _GeminiContent
-    finish_reason: NotRequired[Annotated[Literal['STOP'], Field(alias='finishReason')]]
+    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
     """
     See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
     but let's wait until we see them and know what they mean to add them here.
     """
-    avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
+    avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
     index: NotRequired[int]
-    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
+    safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]


 class _GeminiUsageMetaData(TypedDict, total=False):
@@ -592,10 +634,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """

-    prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
-    candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
-    total_token_count: Annotated[int, Field(alias='totalTokenCount')]
-    cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
+    prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
+    candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
+    total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
+    cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]


 def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
@@ -629,15 +671,15 @@ class _GeminiSafetyRating(TypedDict):
 class _GeminiPromptFeedback(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

-    block_reason: Annotated[str, Field(alias='blockReason')]
-    safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
+    block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
+    safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]


-_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
-_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)
+_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
+_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

 # steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
-_gemini_streamed_response_ta = _pydantic.LazyTypeAdapter(list[_GeminiResponse])
+_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))


 class _GeminiJsonSchema: