pydantic-ai-slim 0.0.18-py3-none-any.whl → 0.0.20-py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

This version of pydantic-ai-slim was flagged as potentially problematic.

pydantic_ai/models/function.py

@@ -7,16 +7,17 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union, cast
+from typing import Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
-from .. import _utils, result
+from .. import _utils, usage
+from .._utils import PeekableAsyncStream
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
-    ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -26,7 +27,7 @@ from ..messages import (
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
+from . import AgentModel, Model, StreamedResponse
 
 
 @dataclass(init=False)
@@ -70,16 +71,15 @@ class FunctionModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
+            self.function,
+            self.stream_function,
+            AgentInfo(function_tools, allow_text_result, result_tools, None),
         )
 
     def name(self) -> str:
-        labels: list[str] = []
-        if self.function is not None:
-            labels.append(self.function.__name__)
-        if self.stream_function is not None:
-            labels.append(f'stream-{self.stream_function.__name__}')
-        return f'function:{",".join(labels)}'
+        function_name = self.function.__name__ if self.function is not None else ''
+        stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
+        return f'function:{function_name}:{stream_function_name}'
 
 
 @dataclass(frozen=True)
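
The name() rewrite above trades the old comma-joined label list for a fixed two-slot format, so the non-streamed and streamed function names always occupy the same positions. A standalone sketch of the new behaviour (function names invented for illustration):

    def model_name(function_name: str | None, stream_function_name: str | None) -> str:
        # Mirrors the new FunctionModel.name() logic shown above.
        fn = function_name or ''
        stream_fn = stream_function_name or ''
        return f'function:{fn}:{stream_fn}'

    assert model_name('my_func', None) == 'function:my_func:'  # empty stream slot kept
    assert model_name(None, 'my_stream_func') == 'function::my_stream_func'
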
@@ -142,106 +142,76 @@ class FunctionAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
+        model_name = f'function:{self.function.__name__}'
+
         if inspect.iscoroutinefunction(self.function):
             response = await self.function(messages, agent_info)
         else:
             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
             assert isinstance(response_, ModelResponse), response_
             response = response_
+        response.model_name = model_name
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_usage(chain(messages, [response]))
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
-        response_stream = self.stream_function(messages, self.agent_info)
-        try:
-            first = await response_stream.__anext__()
-        except StopAsyncIteration as e:
-            raise ValueError('Stream function must return at least one item') from e
-
-        if isinstance(first, str):
-            text_stream = cast(AsyncIterator[str], response_stream)
-            yield FunctionStreamTextResponse(first, text_stream)
-        else:
-            structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
-            yield FunctionStreamStructuredResponse(first, structured_stream)
-
-
-@dataclass
-class FunctionStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
-
-    _next: str | None
-    _iter: AsyncIterator[str]
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            self._buffer.append(self._next)
-            self._next = None
-        else:
-            self._buffer.append(await self._iter.__anext__())
+        model_name = f'function:{self.stream_function.__name__}'
 
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
+        response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
 
-    def usage(self) -> result.Usage:
-        return result.Usage()
+        first = await response_stream.peek()
+        if isinstance(first, _utils.Unset):
+            raise ValueError('Stream function must return at least one item')
 
-    def timestamp(self) -> datetime:
-        return self._timestamp
+        yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
 
 
 @dataclass
-class FunctionStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
+class FunctionStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
-    _next: DeltaToolCalls | None
-    _iter: AsyncIterator[DeltaToolCalls]
-    _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
+    _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
-    async def __anext__(self) -> None:
-        if self._next is not None:
-            tool_call = self._next
-            self._next = None
-        else:
-            tool_call = await self._iter.__anext__()
+    def __post_init__(self):
+        self._usage += _estimate_usage([])
 
-        for key, new in tool_call.items():
-            if current := self._delta_tool_calls.get(key):
-                current.name = _utils.add_optional(current.name, new.name)
-                current.json_args = _utils.add_optional(current.json_args, new.json_args)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for item in self._iter:
+            if isinstance(item, str):
+                response_tokens = _estimate_string_tokens(item)
+                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
             else:
-                self._delta_tool_calls[key] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        calls: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))
-
-        return ModelResponse(calls, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return _estimate_usage([self.get()])
+                delta_tool_calls = item
+                for dtc_index, delta_tool_call in delta_tool_calls.items():
+                    if delta_tool_call.json_args:
+                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
+                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc_index,
+                        tool_name=delta_tool_call.name,
+                        args=delta_tool_call.json_args,
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
     """Very rough guesstimate of the token usage associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
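
With StreamTextResponse and StreamStructuredResponse collapsed into the single FunctionStreamedResponse above, one stream_function can now interleave plain text chunks and tool-call deltas; _get_event_iterator routes each item through the parts manager. A hedged sketch of a conforming stream function (AgentInfo, DeltaToolCall and DeltaToolCalls are the real pydantic_ai.models.function types; the yielded values and tool name are invented):

    from collections.abc import AsyncIterator

    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls

    async def stream_func(
        messages: list[ModelMessage], info: AgentInfo
    ) -> AsyncIterator[str | DeltaToolCalls]:
        yield 'hello '  # text delta, appended to the single 'content' text part
        yield 'world'
        # Tool-call deltas are keyed by index; partial json_args accumulate across yields.
        yield {0: DeltaToolCall(name='get_weather')}
        yield {0: DeltaToolCall(json_args='{"city": "London"}')}
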
@@ -253,28 +223,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _estimate_string_usage(part.content)
+                    request_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _estimate_string_usage(part.model_response_str())
+                    request_tokens += _estimate_string_tokens(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _estimate_string_usage(part.model_response())
+                    request_tokens += _estimate_string_tokens(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _estimate_string_usage(part.content)
+                    response_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
+                    response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
             assert_never(message)
-    return result.Usage(
+    return usage.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )
 
 
-def _estimate_string_usage(content: str) -> int:
-    return len(re.split(r'[\s",.:]+', content))
+def _estimate_string_tokens(content: str) -> int:
+    if not content:
+        return 0
+    return len(re.split(r'[\s",.:]+', content.strip()))
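
The renamed estimator also gains an empty-string guard and a strip(); previously, empty or whitespace-padded content produced phantom tokens because re.split returns empty strings at the boundaries. For example:

    import re

    def _estimate_string_tokens(content: str) -> int:
        if not content:
            return 0
        return len(re.split(r'[\s",.:]+', content.strip()))

    assert _estimate_string_tokens('') == 0              # the old version returned 1
    assert _estimate_string_tokens('hello world') == 2
    assert _estimate_string_tokens('hello world ') == 2  # without .strip() this was 3
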
pydantic_ai/models/gemini.py

@@ -2,24 +2,25 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
+from uuid import uuid4
 
 import pydantic
-import pydantic_core
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, _utils, exceptions, result
+from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -31,10 +32,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
@@ -100,6 +99,7 @@ class GeminiModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -152,7 +152,6 @@ class GeminiAgentModel(AgentModel):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ):
-        check_allow_model_requests()
         tools = [_function_from_abstract_tool(t) for t in function_tools]
         if result_tools:
             tools += [_function_from_abstract_tool(t) for t in result_tools]
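
The two hunks above move the check_allow_model_requests() guard from GeminiAgentModel.__init__ up into GeminiModel.agent_model, so the check now fires when an agent model is built rather than whenever the dataclass is constructed. A hedged sketch of the effect (assuming the usual ALLOW_MODEL_REQUESTS flag in pydantic_ai.models):

    from pydantic_ai import models

    # E.g. in a test suite, to block accidental real API calls:
    models.ALLOW_MODEL_REQUESTS = False

    # After this change, GeminiModel(...).agent_model(...) raises immediately,
    # while constructing a GeminiAgentModel directly no longer triggers the check.
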
@@ -171,7 +170,7 @@ class GeminiAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
@@ -179,7 +178,7 @@ class GeminiAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
@@ -230,15 +229,13 @@ class GeminiAgentModel(AgentModel):
             raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
         yield r
 
-    @staticmethod
-    def _process_response(response: _GeminiResponse) -> ModelResponse:
+    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
         parts = response['candidates'][0]['content']['parts']
-        return _process_response_from_parts(parts)
+        return _process_response_from_parts(parts, model_name=self.model_name)
 
-    @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
+    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -259,11 +256,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        # TODO: Update this once we rework stream responses to be more flexible
-        if _extract_response_parts(start_response).is_left():
-            return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
-        else:
-            return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -302,86 +295,69 @@ class GeminiAgentModel(AgentModel):
 
 
 @dataclass
-class GeminiStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for the Gemini model."""
-
-    _json_content: bytearray
-    _stream: AsyncIterator[bytes]
-    _position: int = 0
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._json_content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        if final:
-            all_items = pydantic_core.from_json(self._json_content)
-            new_items = all_items[self._position :]
-            self._position = len(all_items)
-            new_responses = _gemini_streamed_response_ta.validate_python(new_items)
-        else:
-            all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
-            new_items = all_items[self._position : -1]
-            self._position = len(all_items) - 1
-            new_responses = _gemini_streamed_response_ta.validate_python(
-                new_items, experimental_allow_partial='trailing-strings'
-            )
-        for r in new_responses:
-            self._usage += _metadata_as_usage(r)
-            parts = r['candidates'][0]['content']['parts']
-            if _all_text_parts(parts):
-                for part in parts:
-                    yield part['text']
-            else:
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be text'
-                )
-
-    def usage(self) -> result.Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GeminiStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for the Gemini model."""
+class GeminiStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for the Gemini model."""
 
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._content.extend(chunk)
 
-    def get(self, *, final: bool = False) -> ModelResponse:
-        """Get the `ModelResponse` at this point.
-
-        NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
-        reply with a single response, when returning a structured data.
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for gemini_response in self._get_gemini_responses():
+            candidate = gemini_response['candidates'][0]
+            gemini_part: _GeminiPartUnion
+            for gemini_part in candidate['content']['parts']:
+                if 'text' in gemini_part:
+                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
+                    # amongst the tool call deltas
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
+
+                elif 'function_call' in gemini_part:
+                    # Here, we assume all function_call parts are complete and don't have deltas.
+                    # We do this by assigning a unique randomly generated "vendor_part_id".
+                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
+                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=uuid4(),
+                        tool_name=gemini_part['function_call']['name'],
+                        args=gemini_part['function_call']['args'],
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+                else:
+                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
+
+    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
+        # This method exists to ensure we only yield completed items, so we don't need to worry about
+        # partial gemini responses, which would make everything more complicated
+
+        gemini_responses: list[_GeminiResponse] = []
+        current_gemini_response_index = 0
+        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
+        # But changing that would make things a lot more complicated.
+        async for chunk in self._stream:
+            self._content.extend(chunk)
+
+            gemini_responses = _gemini_streamed_response_ta.validate_json(
+                self._content,
+                experimental_allow_partial='trailing-strings',
+            )
 
-        I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from
-        separate parts.
-        """
-        responses = _gemini_streamed_response_ta.validate_json(
-            self._content,
-            experimental_allow_partial='off' if final else 'trailing-strings',
-        )
-        combined_parts: list[_GeminiPartUnion] = []
-        self._usage = result.Usage()
-        for r in responses:
+            # The idea: yield only up to the latest response, which might still be partial.
+            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
+            # allow_partial API to determine if the last item in the list is complete.
+            responses_to_yield = gemini_responses[:-1]
+            for r in responses_to_yield[current_gemini_response_index:]:
+                current_gemini_response_index += 1
+                self._usage += _metadata_as_usage(r)
+                yield r
+
+        # Now yield the final response, which should be complete
+        if gemini_responses:
+            r = gemini_responses[-1]
             self._usage += _metadata_as_usage(r)
-            candidate = r['candidates'][0]
-            combined_parts.extend(candidate['content']['parts'])
-        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return self._usage
+            yield r
 
     def timestamp(self) -> datetime:
         return self._timestamp
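
The new _get_gemini_responses leans on pydantic's experimental partial JSON validation: raw bytes accumulate in _content, the whole buffer is re-validated on every chunk, and only the last array element is treated as possibly incomplete. A standalone sketch of that trick (assuming a pydantic version that supports experimental_allow_partial, which this code already requires):

    from pydantic import TypeAdapter

    ta = TypeAdapter(list[dict[str, str]])
    buffer = b'[{"text": "done"}, {"text": "partia'  # stream cut off mid-element
    items = ta.validate_json(buffer, experimental_allow_partial='trailing-strings')
    print(items)  # [{'text': 'done'}, {'text': 'partia'}]; only items[:-1] are known-complete
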
@@ -454,18 +430,25 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
     return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
-def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
+def _process_response_from_parts(
+    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
-            items.append(TextPart(part['text']))
+            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
+            items.append(
+                ToolCallPart.from_raw_args(
+                    tool_name=part['function_call']['name'],
+                    args=part['function_call']['args'],
+                )
+            )
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(items, timestamp=timestamp or _utils.now_utc())
+    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
 
 
 class _GeminiFunctionCall(TypedDict):
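
_process_response_from_parts now threads the model name through to the ModelResponse and uses keyword arguments throughout. A small illustration of what the reworked helper builds (model name and tool call invented for the example):

    from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart

    response = ModelResponse(
        parts=[
            TextPart(content='Checking the weather...'),
            ToolCallPart.from_raw_args(tool_name='get_weather', args={'city': 'London'}),
        ],
        model_name='gemini-1.5-flash',  # hypothetical model name
    )
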
@@ -575,35 +558,6 @@ class _GeminiResponse(TypedDict):
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
-# TODO: Delete the next three functions once we've reworked streams to be more flexible
-def _extract_response_parts(
-    response: _GeminiResponse,
-) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
-    """Extract the parts of the response from the Gemini API.
-
-    Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
-    """
-    if len(response['candidates']) != 1:
-        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
-    parts = response['candidates'][0]['content']['parts']
-    if _all_function_call_parts(parts):
-        return _utils.Either(left=parts)
-    elif _all_text_parts(parts):
-        return _utils.Either(right=parts)
-    else:
-        raise exceptions.UnexpectedModelBehavior(
-            f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
-        )
-
-
-def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
-    return all('function_call' in part for part in parts)
-
-
-def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
-    return all('text' in part for part in parts)
-
-
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
@@ -630,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Usage()
+        return usage.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Usage(
+    return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
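
Finally, _metadata_as_usage returns the renamed usage.Usage, and Usage supports +=, which is how GeminiStreamedResponse accumulates per-chunk usage_metadata into a running total. A small sketch (token counts invented):

    from pydantic_ai.usage import Usage

    total = Usage()
    for chunk_usage in [
        Usage(request_tokens=10, response_tokens=2, total_tokens=12),
        Usage(request_tokens=0, response_tokens=3, total_tokens=3),
    ]:
        total += chunk_usage  # mirrors self._usage += _metadata_as_usage(r)

    assert total.response_tokens == 5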